
Comments (9)

tonyhqanguyen commented on June 21, 2024

Hmm... I'm not sure what I just changed, but the loss seems pretty reasonable now. It's improving to about ~4.6, so hopefully the model is actually learning. Thank you so much!


tunz commented on June 21, 2024

Hi! Since I just reimplemented the algorithm in PyTorch, it would be better to ask the tensor2tensor authors if you want a definitive answer. But let me try to explain what I understood at the time of writing the code.

First, this is not "normalized cross entropy". The tensor2tensor code names this cross entropy function smoothing_cross_entropy. I also skimmed the NCE paper, and it does not look related; the "normalizing" constant here is just for readability.

AFAIK, label smoothing comes from the intuition that training data might have wrong labels. Large training datasets usually contain quite a lot of misclassified examples. So it assigns a small confidence value even to the "incorrect" labels, so that the model does not completely ignore the actual class of mislabeled examples in the training dataset.

Then, a small problem is that the cross entropy loss value becomes larger. Even when the model is 100% accurate, the loss is not zero because of the label smoothing. So we just subtract this "normalizing" constant from the cross entropy value; then the loss goes toward zero as the model becomes accurate. This does not affect backpropagation (it is a constant), but it makes it easier to see whether the loss is stuck or converging toward an optimum.
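
For reference, here is a minimal sketch of that idea in PyTorch (my own paraphrase, not the exact code in this repo or in tensor2tensor; the function name and signature are just illustrative):

import math
import torch
import torch.nn.functional as F

def smoothed_cross_entropy(pred, ans, vocab_size, label_smoothing=0.1):
    # pred: (batch * seq_len, vocab_size) raw logits; ans: (batch * seq_len,) target indices
    confidence = 1.0 - label_smoothing
    low_confidence = label_smoothing / float(vocab_size - 1)

    # smoothed target distribution: `confidence` on the true label,
    # `low_confidence` spread over every other vocabulary entry
    target = torch.full_like(pred, low_confidence)
    target.scatter_(1, ans.unsqueeze(1), confidence)

    log_probs = F.log_softmax(pred, dim=1)
    loss = -(target * log_probs).sum(dim=1)  # per-token smoothed cross entropy

    # entropy of the smoothed target itself; subtracting this constant shifts the
    # minimum achievable loss to zero without changing the gradients
    normalizing = -(confidence * math.log(confidence)
                    + float(vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20))
    return (loss - normalizing).mean()

With the normalizing term subtracted, a model whose predicted distribution matches the smoothed targets exactly gets a loss of zero, which is the debugging convenience I meant above.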


tonyhqanguyen commented on June 21, 2024

Ah I see. Thank you very much!


tonyhqanguyen commented on June 21, 2024

Hi, sorry to bother you again; I'm just making sure here.

The argument pred should have shape (batch size * max sequence length, vocab size), and the argument ans should have shape (batch size * max sequence length,), right? And should pred be the raw logits computed by the Transformer model?

The problem I'm having is that when I compute log_softmax of the inputs, the values are significantly negative, around -11, so when I sum each row of log probabilities by calling .sum(dim=1) on them, I get around 3000 for the first few iterations. You said we subtract the normalizing constant from the cross entropy value, but the normalizing constant, as I see while debugging, is tiny compared to the cross entropy: it is < 1, while the cross entropy is around 3000.


tunz commented on June 21, 2024

Yes, that's right.
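
Just to spell out the shapes I have in mind (a toy sketch with made-up sizes, not your actual configuration):

import torch

batch_size, max_seq_len, vocab_size = 2, 5, 150000

# pred: raw logits from the final projection layer, flattened over batch and time
pred = torch.randn(batch_size * max_seq_len, vocab_size)
# ans: target token indices, flattened the same way
ans = torch.randint(0, vocab_size, (batch_size * max_seq_len,))

print(pred.shape)  # torch.Size([10, 150000])
print(ans.shape)   # torch.Size([10])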

I'm not sure what's happening there; it could just be normal as long as it eventually converges. But one thing that seems odd: the normalizing constant should be around 12 if label_smoothing is 0.1 and the vocab size is 150000.

label_smoothing=0.1
vocab_size=150000
confidence = 1.0 - label_smoothing
normalizing = -(confidence * math.log(confidence) + float(vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20))
# 12.01313558123328


tonyhqanguyen commented on June 21, 2024

import math

label_smoothing = 0.1
vocab_size = 150000
confidence = 1.0 - label_smoothing
low_confidence = (1.0 - confidence) / float(vocab_size - 1)
normalizing = -(confidence * math.log(confidence) + float(vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20))

>>> normalizing
1.516921364030397

This is what I get ...


tunz commented on June 21, 2024

Ah, I made a mistake. That's right.


tonyhqanguyen commented on June 21, 2024

Ah, OK, thanks. I'm not sure why the loss goes through the roof when I use label smoothing right now (~3000), but when I don't, it fluctuates at around 5. Let me know if you have any insight into what I could try.


tunz commented on June 21, 2024

What is the range (min/max) of your logit/softmax values over the vocabulary?

You said it's around 5 when label_smoothing is zero. That means

-1 * log_softmax(answer_logit) == 5

and, when label smoothing is on, it's around 3000:

-(0.9*log_softmax(answer_logit) + sum((1/150000.0) * log_softmax(logit) for logit in other_logits)) == 3000

then,

sum((1/150000.0) * log_softmax(logit) for logit in other_logits) == -3000 + 0.9*5

If I assume that all the logits have the same value (there are about 150000 terms, each weighted by 1/150000, so the sum is roughly just that common log_softmax value), then

log_softmax(logit) == -3000 + 0.9*5

But if all the logits had the same value, each softmax value would be around 1/150000, and log_softmax(logit) would have to be around -12. So that case does not make sense.

So I guess the reason your value is close to 3000 is that some of your logit values are much smaller than the others. Try changing the initialization of the embedding layer and see how it goes. I still think this high loss may not be a big problem if the loss converges, or you could also reduce the label smoothing constant.
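
To make that concrete, here is a toy sketch (my own example with made-up logit scales, not your actual model outputs) of how the spread of the logits drives the smoothed loss for a single token:

import torch
import torch.nn.functional as F

vocab_size = 150000
label_smoothing = 0.1
confidence = 1.0 - label_smoothing
low_confidence = label_smoothing / float(vocab_size - 1)

def smoothed_loss(logits, answer_idx):
    # label-smoothed cross entropy for one token position (no normalizing constant)
    log_probs = F.log_softmax(logits, dim=0)
    target = torch.full_like(log_probs, low_confidence)
    target[answer_idx] = confidence
    return -(target * log_probs).sum()

torch.manual_seed(0)

# logits on a small scale: every log-probability is near -log(150000) ~ -12,
# so the loss is also around 12
print(smoothed_loss(torch.randn(vocab_size), answer_idx=0))

# logits with a huge spread: the lowest logits get log-probabilities that are far
# more negative, and even the tiny smoothing weight on them blows the loss up
# by a couple of orders of magnitude
print(smoothed_loss(torch.randn(vocab_size) * 200.0, answer_idx=0))

If the logit scale settles down as training goes on, the smoothed loss should come back toward the ~log(vocab_size) range and then keep decreasing.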

