
Comments (4)

yunjey commented on July 28, 2024

I don't understand what the problem is. The data order in the outputs and the targets is the same because both are passed through pack_padded_sequence. There is no problem.
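A minimal sketch of that ordering claim, assuming the toy targets and lengths from the example in this thread. The `outputs` tensor here is a hypothetical stand-in for model outputs (each position just copies its target id so the alignment is easy to see), not the tutorial's actual model:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Toy batch: 2 sequences, padded to 4 time steps (0 is padding).
targets = torch.tensor([[1, 2, 3, 4], [5, 6, 0, 0]])
lengths = [4, 2]  # true lengths, sorted in decreasing order

# Hypothetical "model outputs": each position carries its own target id
# in every feature, so we can check where each time step ends up.
outputs = targets.float().unsqueeze(-1).repeat(1, 1, 3)  # (batch, time, features)

packed_targets = pack_padded_sequence(targets, lengths, batch_first=True)
packed_outputs = pack_padded_sequence(outputs, lengths, batch_first=True)

# Packing is time-major and drops the padded steps.
print(packed_targets.data)         # tensor([1, 5, 2, 6, 3, 4])
print(packed_targets.batch_sizes)  # tensor([2, 2, 1, 1])

# Row i of the packed outputs lines up with entry i of the packed targets.
print(torch.equal(packed_outputs.data[:, 0].long(), packed_targets.data))  # True
```

Both tensors are reordered by the same rule, so row i of the packed outputs still corresponds to entry i of the packed targets.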

Did you understand what I'm saying? If you don't understand, do not hesitate to ask another question.

from pytorch-tutorial.

hashbangCoder commented on July 28, 2024

OK. Here's an example:

target = [[1,2,3,4], [5,6,0,0]]   #batch*time-step
model_out = torch.Tensor(2,4,vocab_size)

Now for the sequence loss, consider the first sequence and its targets, model_out[0] and target[0]. You're supposed to take the CrossEntropyLoss (CEL) at every time step, i.e.:

sequence_loss[0] = CEL(model_out[0][0],target[0][0]) + ... + CEL(model_out[0][3],target[0][3])

Similarly for the second sequence. Net mini-batch loss = sum(sequence_loss)/2.0. Now, what pack_padded_sequence does is convert everything into a single mini-batch.

packed = pack_padded_sequence(embed_inp, [4, 2], batch_first=True)    # embed_inp is the output of Embedding(10,10,0)

 1.3866 -1.4917  0.1638  0.3532  2.9881  0.0461  0.6481 -1.2511  2.1563 -1.0258
-1.1634 -0.1845  0.0265  1.3146 -1.2236 -1.3518  1.2663 -0.1945  0.0213 -0.1681
-0.1127 -1.6191 -0.3237 -1.4512  1.0955 -0.8381  1.2835  0.5311  1.2359 -1.5673
 0.3134 -0.3416 -1.2120 -1.7240  1.3178 -0.2555  0.2030  0.8582 -0.5273 -0.0708
 0.6825 -0.3805 -1.1349  0.6390 -1.5805 -2.4762 -0.6307 -0.2127  0.8466 -0.2337
 0.4657 -1.5292  0.7106  1.4591 -0.7695 -0.5821 -2.2042  0.1225 -0.4959 -0.2382
[torch.FloatTensor of size 6x10]


model_out = lstm(packed)[0]     #hidden states from nn.LSTM(10,20)

Columns 0 to 9 
 0.0251 -0.2013 -0.0108  0.0996  0.1172  0.0813  0.1368  0.0036 -0.0472 -0.0974
-0.0651  0.0791  0.0676 -0.0195 -0.0979  0.0400  0.0426  0.0095 -0.0260 -0.0737
-0.0977 -0.0902  0.0275  0.1105  0.1506  0.1077  0.1157  0.1418 -0.0823 -0.2376
-0.1670 -0.0537  0.0964  0.0857 -0.0417  0.0553  0.0674  0.1317 -0.0238 -0.1653
-0.1909 -0.0686  0.2268 -0.1196 -0.0300 -0.0605  0.1941  0.1778 -0.0816 -0.0824
-0.0367 -0.0666  0.1007 -0.2088 -0.1325 -0.1464  0.0981  0.2038 -0.0758  0.0281

Columns 10 to 19 
 0.2482 -0.0968 -0.0109  0.0505 -0.1285  0.0049  0.2775 -0.1530 -0.1499  0.1109
 0.0191 -0.0319 -0.0808 -0.1486 -0.0136  0.0800 -0.0394 -0.0267  0.0872  0.0202
 0.2717 -0.0678 -0.0742  0.1615 -0.0917 -0.0033  0.1297 -0.2521 -0.2597  0.1603
-0.1123 -0.0243 -0.1594 -0.0173  0.0872  0.0040 -0.1204 -0.1173  0.1110  0.0384
 0.1221 -0.1210 -0.1716  0.1877 -0.1433  0.0564  0.0515 -0.1887  0.0390  0.0576
 0.1573 -0.1564  0.1000  0.2310 -0.1333  0.0478  0.0689 -0.0405  0.0795 -0.0070
[torch.FloatTensor of size 6x20]
, batch_sizes=[2, 2, 1, 1])

When you apply CrossEntropyLoss on this and pack_padded_sequence(targets), you're just averaging the loss over all time steps of all sequences, right? Without the per-sequence summation.

FWIW, I don't think there's much difference, just that the gradients will change by a constant factor, which may speed up or slow down training.
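A rough sketch comparing the two reductions on the toy batch above. The random `logits` tensor is an assumption standing in for the model's vocabulary scores (the LSTM hidden states would first go through a linear layer), and `vocab_size = 10` is picked only for illustration:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
vocab_size = 10                                    # assumed for illustration
targets = torch.tensor([[1, 2, 3, 4], [5, 6, 0, 0]])
lengths = [4, 2]
logits = torch.randn(2, 4, vocab_size)             # stand-in for model scores

# (a) Sum CEL over the real time steps of each sequence, then average over the batch.
seq_losses = [
    F.cross_entropy(logits[b, :n], targets[b, :n], reduction="sum")
    for b, n in enumerate(lengths)
]
loss_summed = torch.stack(seq_losses).sum() / len(lengths)

# (b) Pack both tensors and take the default mean over all real time steps.
packed_logits = pack_padded_sequence(logits, lengths, batch_first=True).data
packed_targets = pack_padded_sequence(targets, lengths, batch_first=True).data
loss_packed = F.cross_entropy(packed_logits, packed_targets)

# Same per-step terms, different denominators: batch size (2) vs. total steps (6).
ratio = sum(lengths) / len(lengths)                # 3.0
print(torch.allclose(loss_summed, loss_packed * ratio))  # True
```

For a fixed batch the two losses differ only by the factor total_steps / batch_size, so the gradients point in the same direction. Note that this factor changes from batch to batch when the sequence lengths vary.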


yunjey commented on July 28, 2024

I usually divide the loss by the sequence length. As you mentioned above, this is likely to have only a small impact on performance.


yunjey commented on July 28, 2024

@hashbangCoder You may find it helpful to look at im2txt.
