Comments (8)

michaelklachko commented on June 22, 2024

Here's how I did it:

import torch.nn as nn
import torch.nn.functional as F

for i in range(num_train_batches):
    input = train_inputs[i*batch_size:(i+1)*batch_size]
    label = train_labels[i*batch_size:(i+1)*batch_size]
    output = model(input)                   # student logits
    teacher_output = teacher_model(input)   # teacher logits
    # Soft (teacher) term scaled by alpha*T^2, plus hard (label) term scaled by (1 - alpha)
    loss = nn.KLDivLoss()(F.log_softmax(output / T, dim=1),
                          F.softmax(teacher_output / T, dim=1)) * alpha * T * T \
           + F.cross_entropy(output, label) * (1 - alpha)

I see a consistent improvement of ~1% on CIFAR-10 with T=6 and alpha=0.9.

from knowledge-distillation-pytorch.

DavidBegert commented on June 22, 2024

Am I wrong that the teacher outputs are supposed to be aligned with the student outputs?

from knowledge-distillation-pytorch.

haitongli commented on June 22, 2024

It makes sense that they're different. The idea of knowledge distillation goes beyond simply adding extra supervision such as the final output labels. My understanding is that the student is not mimicking the teacher's output labels, but learning on its own from the training data while being regularized by the "dark knowledge" from the teacher. If you look at the KD loss function, it is a joint of hard labels (training samples) and soft labels (teacher outputs). Hinton's paper has an insightful discussion of this.

from knowledge-distillation-pytorch.
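For reference, that joint loss can be written as two named terms. This is only a sketch in the style of the snippet earlier in the thread: kd_loss is an illustrative name, not a function from the repo, and reduction='batchmean' is used here for a per-sample average rather than KLDivLoss's default.

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, alpha, T):
    # Soft term: KL divergence between temperature-softened student and teacher
    # distributions, scaled by T^2 as in Hinton et al.
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * alpha * T * T
    # Hard term: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels) * (1 - alpha)
    return soft + hard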

zym1119 commented on June 22, 2024

If alpha=0.95 and T=6, then alpha*T*T would be far larger than 1-alpha, so I don't think it's true that the student is not mimicking the teacher's output; the loss depends almost entirely on the KL divergence between the student and the teacher (see the quick check below).
Have you tried training a student on ImageNet?

from knowledge-distillation-pytorch.
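A quick numeric check of that weighting, using the values from the comment above:

alpha, T = 0.95, 6
soft_weight = alpha * T * T       # 34.2
hard_weight = 1 - alpha           # 0.05
print(soft_weight / hard_weight)  # ~684, so the soft (teacher) term dominates the total loss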

haitongli commented on June 22, 2024

It might be interesting to run a hyperparameter exploration (using search_hyperparams.py) for alpha and T (I did that for my course project, but not intensively due to time/resource limitations); a minimal sweep is sketched below. If we use a near-one alpha (forcing the student to mimic the teacher more) with an improper temperature, we sometimes see a sharp drop in accuracy. Sometimes, though, even with a small alpha, much knowledge can be distilled from the teacher into the student (measured by accuracy improvement). My point was simply that the loss contribution from the distillation part can be more complicated than the intuitive sense of "mimicking" the teacher's output labels. It also relates to the bias-variance tradeoff from a traditional ML perspective.

No, I haven't touched ImageNet due to resource constraints...

from knowledge-distillation-pytorch.
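A minimal sweep over alpha and T might look like the sketch below. This is not the repo's search_hyperparams.py; train_and_eval is a hypothetical helper that trains a student with the given settings and returns validation accuracy.

import itertools

def sweep(train_and_eval, alphas=(0.1, 0.5, 0.9, 0.95, 0.99), temps=(1, 2, 4, 6, 10)):
    # Exhaustive grid over (alpha, T); returns the best pair and the full results table.
    results = {}
    for alpha, T in itertools.product(alphas, temps):
        results[(alpha, T)] = train_and_eval(alpha=alpha, T=T)
    return max(results, key=results.get), results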

zym1119 commented on June 22, 2024

Thanks for answering. I need to take a detailed look at your report and do some experiments.

from knowledge-distillation-pytorch.

zym1119 commented on June 22, 2024

I split the loss into two parts: one is the cross entropy between the student outputs and the teacher outputs with temperature T, and the other is the cross entropy between the student outputs and the labels. The first loss is the "soft" loss and the second is the "hard" loss.
In your code, due to the misalignment between teacher outputs and student outputs, the "soft" loss stays essentially constant during training at about 0.045, while the "hard" loss is optimized from 0.198 down to a relatively low value.
This means the network is always learning from the targets rather than from the teacher; the "soft" loss is just noise for the student. I guess this acts as a kind of regularization and causes the student network to generalize slightly better than before.

from knowledge-distillation-pytorch.
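One way to see this in practice (a sketch reusing the variable names from the training loop earlier in the thread) is to compute and log the two terms separately; a flat "soft" curve is a symptom of the misalignment described above.

# Soft term: temperature-scaled KL divergence against the teacher outputs.
soft_loss = nn.KLDivLoss()(F.log_softmax(output / T, dim=1),
                           F.softmax(teacher_output / T, dim=1)) * alpha * T * T
# Hard term: cross-entropy against the ground-truth labels.
hard_loss = F.cross_entropy(output, label) * (1 - alpha)
loss = soft_loss + hard_loss
# Logging soft_loss.item() and hard_loss.item() each step shows which term actually decreases.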

chenxshuo commented on June 22, 2024

Am I wrong that the teacher outputs are supposed to be aligned with the student outputs?

I think these two outputs should be aligned. It seems that @peterliht recorded the teacher outputs at the beginning and reuses them during student training by indexing through enumerate(dataloader). But the dataloader shuffles every epoch, so the batches are not the same for the teacher and the student.

from knowledge-distillation-pytorch.
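If strict alignment is the goal, one option (a sketch with model and loader names assumed, not a patch from the repo) is to run the teacher on the same batch inside the student's training loop instead of indexing pre-recorded outputs:

import torch

teacher_model.eval()
for input_batch, label_batch in dataloader:
    with torch.no_grad():                    # the teacher needs no gradients
        teacher_output = teacher_model(input_batch)
    output = student_model(input_batch)      # same batch, so the outputs are aligned
    # ...compute the KD loss from output, teacher_output, and label_batch as above...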
