
Comments (10)

martiansideofthemoon avatar martiansideofthemoon commented on May 17, 2024

Hi Ahmad, thanks for your interest! An accuracy of 31.9% is worse than random guessing on a three-class task. A few questions to help debug this:

  1. What is the class distribution of the extracted data?
  2. What scheme did you use, RANDOM / WIKI?
  3. What was the dev set accuracy of the teacher model?

from language.

ahmadrash avatar ahmadrash commented on May 17, 2024

Thanks a lot for the prompt response.

  1. The class distribution is [26.76%, 26.31%, 46.93%] respectively
  2. I used DATA_SCHEME="random_ed_k_uniform"
  3. Dev set accuracy of the teacher model is 0.851
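For reference, the numbers in this exchange can be checked against the chance level for a three-way task. This is only a sketch of the arithmetic; the variable names are illustrative and not taken from the repository.

```python
# Class distribution reported above, as percentages for the three labels.
class_distribution = [26.76, 26.31, 46.93]

# Uniform random guessing over k classes is right 1/k of the time on average.
chance_accuracy = 100.0 / len(class_distribution)

# Student accuracy quoted at the start of the thread.
reported_accuracy = 31.9

# Index of the most frequent label in the extracted data.
majority_index = class_distribution.index(max(class_distribution))

print(f"distribution sums to {sum(class_distribution):.2f}%")
print(f"chance level: {chance_accuracy:.1f}%")
print(f"student below chance: {reported_accuracy < chance_accuracy}")
print(f"most frequent label index: {majority_index}")
```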


martiansideofthemoon avatar martiansideofthemoon commented on May 17, 2024

Hi Ahmad,
1, 2 and 3 look good to me. A few more follow-up questions:

  1. I guess you are using BERT-large?
  2. Are you using this file to train the student model? https://github.com/google-research/language/blob/master/language/bert_extraction/steal_bert_classifier/models/run_classifier_distillation.py
  3. Is the training loss decreasing? (just confirming if the weight updates are happening)
  4. Does the same script work for SST2 / SQuAD?


ahmadrash avatar ahmadrash commented on May 17, 2024

Thanks Kalpesh,

  1. Yes, I am using BERT-large.

  2. Yes I am using the file.

  3. I am adding the loss curve from TensorBoard. It shows oscillations.
    [Image: mnli_loss]

  4. I am still running the other experiments.


martiansideofthemoon avatar martiansideofthemoon commented on May 17, 2024

Regarding your curve, how many epochs are you training for, and what's your batch size? A loss of 1.1 indicates nothing is being learnt, but I do see a strong decrease over roughly the first 10k steps. Also, what are your learning rate, optimizer and learning rate schedule? Finally, what hardware are you using?


ahmadrash avatar ahmadrash commented on May 17, 2024

I am training it for 3 epochs. I have a batch size of 8 on an NVIDIA V100 GPU. The learning rate, optimizer and schedule are the defaults in the script.

--learning_rate=3e-5
--warmup_proportion=0.1

And the optimizer is the same as the default for BERT.


martiansideofthemoon avatar martiansideofthemoon commented on May 17, 2024

I think the batch size might be the issue. Learning is less stable for RANDOM than for the original MNLI, and smaller batch sizes (hence noisier gradient estimates) could push the model off the optimization path. I'd recommend trying batch size 32. If it doesn't fit on the GPU, you could try using BERT-base or gradient accumulation.
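To illustrate the gradient accumulation idea, here is a minimal sketch on a toy one-parameter model. It is not the training script's code. It assumes a mean-reduced loss and equal-sized micro-batches, in which case averaging the gradients of four micro-batches of 8 matches the gradient of one batch of 32. The helper `mean_grad` and the toy data are hypothetical.

```python
import random

def mean_grad(w, batch):
    """Gradient w.r.t. w of the mean squared error for the toy model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

random.seed(0)
data = [(random.uniform(-1, 1), random.uniform(-1, 1)) for _ in range(32)]
w = 0.5

# Gradient from one large batch of 32 examples.
full_grad = mean_grad(w, data)

# Same 32 examples as four micro-batches of 8: accumulate the scaled
# gradients and apply a single weight update only after the last one.
micro_batch_size = 8
accumulation_steps = len(data) // micro_batch_size
accumulated_grad = 0.0
for i in range(accumulation_steps):
    micro = data[i * micro_batch_size:(i + 1) * micro_batch_size]
    accumulated_grad += mean_grad(w, micro) / accumulation_steps

print(abs(full_grad - accumulated_grad) < 1e-9)
```

The memory cost per step stays that of a batch of 8, at the price of four forward/backward passes per weight update.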

Another thing you could try is learning rate decay. From your graph, it is clear that the training loss falls during the warmup phase of training, but after that the learning rate is too high and a single bad gradient (from a small batch) can derail the optimization. You could also simply try a smaller learning rate, maybe 1e-5.
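As a rough sketch of what a warmup-then-decay schedule looks like, the function below ramps the learning rate linearly up to a peak and then linearly down to zero. It is an illustration only: the function name, the 1000-step total and the exact shape are assumptions, and the schedule implemented in the script itself may differ.

```python
def learning_rate(step, total_steps, peak_lr=1e-5, warmup_proportion=0.1):
    """Linear warmup to peak_lr, then linear decay to zero (illustrative)."""
    warmup_steps = int(total_steps * warmup_proportion)
    if warmup_steps > 0 and step < warmup_steps:
        # Warmup phase: ramp from 0 up to peak_lr.
        return peak_lr * step / warmup_steps
    # Decay phase: fall from peak_lr at the end of warmup to 0 at total_steps.
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / max(total_steps - warmup_steps, 1)

# Hypothetical run of 1000 steps, printing a few points along the schedule.
for step in (0, 50, 100, 550, 1000):
    print(step, learning_rate(step, total_steps=1000))
```

With a lower peak and a decaying tail, one noisy gradient late in training moves the weights less than it would at a constant high rate.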


ahmadrash avatar ahmadrash commented on May 17, 2024

Thanks a lot for the suggestions. I will try these and report back.


ahmadrash avatar ahmadrash commented on May 17, 2024

Thanks Kalpesh! I was able to get 78 on MNLI dev and 90 on SST-2 by reducing the learning rate to 1e-5. The loss curve is still not ideal, but it is much better than what we were seeing before.


Jimntu avatar Jimntu commented on May 17, 2024

Hi, I am a beginner in deep learning and have little experience implementing code. May I ask how you plotted the loss curve from TensorBoard? I would really appreciate it if you could help me!

