Comments (21)
OK, I understand what you mean. This is my naive implementation. I only tested it on MNIST, so speed is not very important 😅:
x, target = x.to(device), target.to(device)
with torch.no_grad():
    out = teacher(x)
    soft_target = F.softmax(out/T, dim=1)
hard_target = target
out = student(x)  ## this is the input to softmax
logp = F.log_softmax(out/T, dim=1)
loss_soft_target = -torch.mean(torch.sum(soft_target * logp, dim=1))
loss_hard_target = nn.CrossEntropyLoss()(out, hard_target)
loss = loss_soft_target * T * T + alpha * loss_hard_target
from knowledge-distillation-pytorch.
I see. You can refer to the documentation of PyTorch's KL divergence loss (KLDivLoss). It expects the input to be log-probabilities and the target to be probabilities, which is why we apply log-softmax to the student outputs and softmax to the teacher outputs (both of which were raw scores).
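A minimal sketch of that calling convention, assuming PyTorch is installed. The helper name `soft_kl`, the tensor shapes, and the temperature value are illustrative assumptions, not code from the repository. It passes log-probabilities for the student and probabilities for the teacher, then compares the result with a hand-written KL sum.

```python
import torch
import torch.nn.functional as F

def soft_kl(student_logits, teacher_logits, T=4.0):
    """KL(teacher || student) on temperature-softened outputs (hypothetical helper)."""
    log_q = F.log_softmax(student_logits / T, dim=1)  # input: log-probabilities
    p = F.softmax(teacher_logits / T, dim=1)          # target: probabilities
    # 'batchmean' sums over classes and averages over the batch
    return F.kl_div(log_q, p, reduction="batchmean")

torch.manual_seed(0)
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
loss = soft_kl(student_logits, teacher_logits)

# Manual reference: sum_i p_i * (log p_i - log q_i), averaged over the batch
p = F.softmax(teacher_logits / 4.0, dim=1)
log_q = F.log_softmax(student_logits / 4.0, dim=1)
manual = (p * (p.log() - log_q)).sum(dim=1).mean()
```

Swapping the two arguments, or passing raw probabilities as the input, would silently compute something other than the intended divergence, so the order matters.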
@xmfbit Indeed, initially I was trying to directly implement cross entropy with the soft targets. However, note that in PyTorch the built-in CrossEntropy loss function only takes "(output, target)" where the target (i.e., label) is a class index, not a one-hot or soft distribution (which is what the KD loss needs). That's why I turned to KL divergence: the two lead to the same optimization result, and KL divergence works naturally with our data representations.
You're welcome to try to define a customized CrossEntropy loss function that also leverages PyTorch’s optimized C-backend (you could also define one from scratch, but that might be very slow). If successful, please let us know. Thanks!
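One possible way to write such a loss from built-in tensor ops, offered as a sketch rather than a tuned implementation. The function name `soft_cross_entropy` and the tensor shapes are hypothetical. It relies on `F.log_softmax`, so the numerically sensitive part still runs in PyTorch's native kernels, and the check below compares its gradient with the KL version.

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(student_logits, soft_targets):
    """Cross entropy against a full target distribution, averaged over the batch."""
    log_q = F.log_softmax(student_logits, dim=1)
    return -(soft_targets * log_q).sum(dim=1).mean()

torch.manual_seed(0)
logits = torch.randn(6, 5, requires_grad=True)
soft_targets = F.softmax(torch.randn(6, 5), dim=1)

# Gradient of the soft-target cross entropy w.r.t. the student logits
ce = soft_cross_entropy(logits, soft_targets)
(grad_ce,) = torch.autograd.grad(ce, logits)

# Gradient of the KL divergence (summed over classes, averaged over the batch)
kl = F.kl_div(F.log_softmax(logits, dim=1), soft_targets, reduction="batchmean")
(grad_kl,) = torch.autograd.grad(kl, logits)

# With one-hot targets the function reduces to the usual hard-label loss
labels = torch.tensor([0, 1, 2, 3, 4, 0])
one_hot = F.one_hot(labels, num_classes=5).float()
hard_ce = soft_cross_entropy(logits, one_hot)
```

The two gradients should agree because the losses differ only by the entropy of the targets, which does not depend on the student.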
Yes, it does not take part in the backward pass, so we can ignore this term.
H(p, q): Cross Entropy of p, q
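For reference, the standard identity relating the two quantities, written with p as the teacher's soft target and q as the student's softmax output over C classes. Reading H(p) as the teacher entropy is an assumption about the notation used here.

```latex
\mathrm{KL}(p \,\|\, q)
  = \sum_{i=1}^{C} p_i \log \frac{p_i}{q_i}
  = \underbrace{-\sum_{i=1}^{C} p_i \log q_i}_{H(p,\,q)}
    \;-\; \underbrace{\Bigl(-\sum_{i=1}^{C} p_i \log p_i\Bigr)}_{H(p)}
```

Because H(p) depends only on the teacher, its derivative with respect to the student's parameters is zero, so minimizing the KL divergence and minimizing H(p, q) give the same gradients.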
Great, thanks!
@michaelklachko Sorry, but I found some bugs in the code I provided. (When checking the gradients with print, grad refers to the same instance as grad_check, so they are necessarily equal!) I am pasting a corrected version below.
I also want to point out that the author's implementation of the KD loss using torch.nn.KLDivLoss causes the gradient to be scaled down by the number of classification categories, compared with using CrossEntropy. See the torch documentation for details.
size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample.
reduction (string, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘elementwise_mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘elementwise_mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: ‘elementwise_mean’
import torch
import torch.nn as nn
import torch.nn.functional as F
# number of samples
N = 10
# number of categories
C = 5
# softmax output of the teacher
p = torch.softmax(torch.rand(N, C), dim=1)
# logit output of the student
s = torch.rand(N, C, requires_grad=True)
# softmax output of the student, T = 1
q = torch.softmax(s, dim=1)
# KL divergence
# this is the author's implementation;
# torch takes the element-wise mean because that is the default reduction
# kl_loss = nn.KLDivLoss()(torch.log(q), p)
# I think this should be the right solution: sum over classes, then average over the batch
kl_loss = (nn.KLDivLoss(reduction='none')(torch.log(q), p)).sum(dim=1).mean()
kl_loss.backward(retain_graph=True)
print('grad using KLDivLoss')
print(s.grad)
# clear the grad
s.grad.zero_()
# bug2: should not take the element-wise mean
# ce_loss = torch.mean(-torch.log(q) * p)
ce_loss = torch.mean(torch.sum(-torch.log(q) * p, dim=1))
ce_loss.backward()
print('grad using ce loss')
print(s.grad)
# the real gradient of s should be `(q - p) / batch_size`
print('real grad, should be (q-p) / batch_size')
print((q - p) / N)
@peterliht Could you check this?
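A quick numerical check of the scaling claim above, assuming a recent PyTorch in which the element-wise average is spelled `reduction='mean'` (the quoted documentation calls it 'elementwise_mean') and `reduction='batchmean'` is available. The sizes N and C are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 10, 5
p = F.softmax(torch.rand(N, C), dim=1)      # teacher probabilities
s = torch.rand(N, C, requires_grad=True)    # student logits
log_q = F.log_softmax(s, dim=1)

# Element-wise mean: the summed loss is divided by N * C
loss_mean = F.kl_div(log_q, p, reduction="mean")
(grad_mean,) = torch.autograd.grad(loss_mean, s, retain_graph=True)

# Batch mean: sum over classes, then divide by N only
loss_batch = F.kl_div(log_q, p, reduction="batchmean")
(grad_batch,) = torch.autograd.grad(loss_batch, s)

# Analytic gradient of the batch-averaged loss w.r.t. the logits
analytic = (F.softmax(s, dim=1) - p).detach() / N
```

Under these assumptions `grad_batch` equals `grad_mean * C`, so with many classes the element-wise mean shrinks both the loss value and its gradient by a factor of C.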
@xmfbit H(p) is constant, right?
It would help me understand your question better if you could mention which file and lines you are referring to.
Oh, it is in the loss_fn_kd function in model/net.py, line 107.
Thanks for your reply.
@peterliht Why is KL divergence used to compute the KD loss? The paper "Distilling the Knowledge in a Neural Network" says: "The first objective function is the cross entropy with the soft targets."
Shouldn't the KD loss be -\sum_{i=1}^C soft_target_i * \log(softmax(student_cls_output / T)_i)?
Has anyone tried cross-entropy? Does it work better or worse than KL?
No. They should lead to the same or similar results, given the discussion above.
@michaelklachko The gradients with respect to the student's output are the same for KL divergence and for the classic KD loss from Hinton's paper. You can refer to the figure given by nowgood. Use this code to check it numerically (p and q differ from those in nowgood's figure).
import torch
import torch.nn as nn
import torch.nn.functional as F
N = 10
C = 5
# softmax output by teacher
p = torch.softmax(torch.rand(N, C), dim=1)
# softmax output by student
q = torch.softmax(torch.rand(N, C), dim=1)
#q = torch.ones(N, C)
q.requires_grad = True
# KL divergence
kl_loss = nn.KLDivLoss()(torch.log(q), p)
kl_loss.backward()
# caution: grad is a reference to q.grad, not a copy
grad = q.grad
q.grad.zero_()
ce_loss = torch.mean(torch.log(q) * p)
ce_loss.backward()
# caution: grad_check is the same tensor object as grad, so the two prints always match
grad_check = q.grad
print(grad)
print(grad_check)
Great, thanks!
When I tried to use the loss code provided by the author on a new task with 1000 categories, I found that the KL loss term is too small, around 1e-6, and on both CIFAR10 and my task it seems the KL loss never decreases. Can you tell me how to fix this problem? Thank you.
Try decreasing your learning rate, or train with CE loss to check for bugs.
@Bo396543018 The same here, did you solve the problem?
Hi,
Using the implementation from #2 (comment), I am getting large soft loss values around 200, but using the https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 implementation, my soft loss value is around 1e-6. This is surprising, because we would expect the results to be similar.
Any help would be appreciated. Thanks!
So what is the answer to: why softmax for the teacher output, but log-softmax for the student output?
@pratikchhapolika
https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
It's just how PyTorch's KLDivLoss() takes its arguments:
input = predicted log-softmax
target = softmax, if log_target == False
As shown above, KL divergence = cross entropy - entropy.
log(p(y)) = student output
q(y) = teacher output
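To illustrate the argument convention described above, a small sketch assuming a PyTorch version recent enough to support the `log_target` argument of `F.kl_div`. The tensor shapes and variable names are arbitrary choices for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(4, 3)
teacher_logits = torch.randn(4, 3)

# input: student log-probabilities in both calls
log_q = F.log_softmax(student_logits, dim=1)

# Target given as probabilities (the default, log_target=False)
kl_prob = F.kl_div(log_q, F.softmax(teacher_logits, dim=1),
                   reduction="batchmean", log_target=False)

# Same target given as log-probabilities
kl_log = F.kl_div(log_q, F.log_softmax(teacher_logits, dim=1),
                  reduction="batchmean", log_target=True)
```

Both calls should return the same value; only the representation of the teacher distribution changes, while the student side is always passed in log space.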