Comments (11)
I think the random seed can only make the behavior the same across different runs.
But it cannot make the behavior the same across different epochs within a single run.
from knowledge-distillation-pytorch.
@HisiFish Have you solved this problem? Is it possible to compute the teacher output from the same input?
--Updated--
Actually, it helps, increasing the accuracy by 0.10-0.20%.
Can you clarify the question a bit more? What is the specific concern?
So the student model is trained the same way the teacher model was. For one epoch, the training batches are used to compute the KD loss to train the student. Then in another epoch, although the dataloader is shuffled, the KD loss should still be correct given the new batches.
For example, suppose we have a dataset with 20 [image, label] pairs and set the batch size to 4, so there are 5 iterations in each epoch. Mark the original data indices 0~19.
In the code, we first fetch the teacher outputs in one epoch; the shuffled indices might be [[0,5,6,8],[7,9,2,4],[...],[...],[...]].
Then in KD training, in another epoch, we need to calculate the KD loss from (student outputs & teacher outputs & the labels). In the current epoch, the indices may be shuffled to [[1,3,6,9],[10,2,8,7],[...],[...],[...]]. In the code (train.py:215), we get output_teacher_batch by i, which is the new iteration index. When i is 0, the teacher outputs come from data [0,5,6,8] while the student outputs come from data [1,3,6,9].
I don't know whether my understanding is incorrect. Thanks!
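The mismatch described above can be demonstrated with a small sketch (hypothetical toy data, not the repo's actual pipeline): two separate passes over a shuffled DataLoader generally produce different batch orders, so teacher outputs cached during one pass do not line up with the student's batches in the next.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the 20-sample example above: labels double as sample indices.
dataset = TensorDataset(torch.zeros(20, 1), torch.arange(20))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Pass 1: the order in which teacher outputs would be precomputed and cached.
teacher_order = [labels.tolist() for _, labels in loader]
# Pass 2: the order seen later during student training (re-shuffled).
student_order = [labels.tolist() for _, labels in loader]

print(teacher_order[0])  # e.g. [0, 5, 6, 8]
print(student_order[0])  # generally a different set of indices
```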
Sorry, I did not fully understand. If you have time and are interested, could you run a test based on your understanding? Right now the KD-trained accuracies are consistently higher than those of the native models, though only by a bit. If your modification works better or makes more sense, feel free to make a pull request. Thanks in advance!
OK, I'll do that once I reach a conclusion. Thanks.
Wait, I think I get what you were saying. Basically, we need to verify that during training of the student model, the batch sequence in the train dataloader at each epoch stays the same as the one used during training of the teacher model. To that end, I think PyTorch should be able to take care of that when a random seed is specified for reproducibility?
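A quick check suggests a global seed is not enough: it makes the whole run reproducible across runs, but each epoch inside a run still draws a fresh shuffle. A minimal sketch on toy data (assuming the DataLoader's default sampling behavior):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def epoch_orders(seed, n_epochs=2):
    """Label order seen in each epoch under a given global seed."""
    torch.manual_seed(seed)
    dataset = TensorDataset(torch.zeros(20, 1), torch.arange(20))
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    return [[labels.tolist() for _, labels in loader] for _ in range(n_epochs)]

run_a = epoch_orders(seed=0)
run_b = epoch_orders(seed=0)
print(run_a == run_b)        # True: same seed, identical across runs
print(run_a[0] == run_a[1])  # usually False: epochs within a run still differ
```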
Maybe not.
It's easy to verify. The following is a simple example:
dataloader = ...  # a DataLoader with shuffle=True

for epoch in range(10):
    for i, (img_batch, label_batch) in enumerate(dataloader):
        if i == 0:
            print(label_batch)

By comparing the first batch of each of the 10 epochs, we can see that the label order differs from epoch to epoch.
Do you know what happens when you don't use enumerate but get batches via next(iter(data_loader))?
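For what it's worth, next(iter(data_loader)) builds a brand-new iterator on every call, so with shuffle=True each call triggers a fresh shuffle and you only ever see a "first" batch; reusing a single iterator advances through one pass instead. A small sketch on toy data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(20, 1), torch.arange(20))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Each next(iter(...)) call re-creates the iterator, re-shuffling every time.
_, first_a = next(iter(loader))
_, first_b = next(iter(loader))  # another freshly shuffled "first" batch

# Reusing one iterator walks through a single epoch without replacement.
it = iter(loader)
_, batch1 = next(it)
_, batch2 = next(it)
print(set(batch1.tolist()) & set(batch2.tolist()))  # set(): batches are disjoint
```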
@HisiFish yes, you are right.
I put the teacher model and the student model together. Finally, it works.
e.g.:

for img_batch, label_batch in dataloader:
    y_student = f_student(img_batch)
    with torch.no_grad():
        y_teacher = f_teacher(img_batch)

refer to: https://github.com/szagoruyko/attention-transfer/blob/master/cifar.py
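With both models in the same loop, the KD loss can then be computed per batch. A sketch of the standard Hinton-style distillation loss (T and alpha are illustrative hyperparameters, not necessarily this repo's values):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.9):
    # Soft-target term: KL between temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # Hard-target term: ordinary cross-entropy against the true labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard
```

Inside the loop above this would be something like loss = kd_loss(y_student, y_teacher, label_batch), followed by the usual backward/step.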
Hi @luhaifeng19947, I haven't followed the discussions here for a while. Are you interested in initiating a pull request?