
Comments (21)

xmfbit commented on July 21, 2024

OK, I understand. This is my naive implementation. I only tested it on MNIST, so speed was not important 😅:

        x, target = x.to(device), target.to(device)
        with torch.no_grad():
            out = teacher(x)
            soft_target = F.softmax(out/T, dim=1)
        hard_target = target
        out = student(x)  # student logits (the input to softmax)
        logp = F.log_softmax(out/T, dim=1)
        loss_soft_target = -torch.mean(torch.sum(soft_target * logp, dim=1))
        loss_hard_target = nn.CrossEntropyLoss()(out, hard_target)
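        # scale the soft loss by T*T so its gradient magnitude stays comparable to the hard loss (as suggested in Hinton et al., 2015)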
        loss = loss_soft_target * T * T + alpha * loss_hard_target

haitongli commented on July 21, 2024

I see. You can refer to the definition/documentation of PyTorch's KL divergence loss (KLDivLoss). It requires the input to be log-probabilities and the target to be probabilities, and that's why we're using log-softmax on the student outputs and softmax on the teacher outputs (both are raw scores).
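For concreteness, here is a minimal sketch of that pairing (the tensor names, the temperature value, and the 'batchmean' reduction are my own choices for illustration, not necessarily the repo's exact code):

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 4.0                              # temperature (example value, my assumption)
teacher_logits = torch.randn(8, 10)  # stand-in raw teacher scores
student_logits = torch.randn(8, 10)  # stand-in raw student scores

# KLDivLoss expects input = log-probabilities and target = probabilities,
# hence log_softmax on the student and softmax on the teacher
kd_loss = nn.KLDivLoss(reduction='batchmean')(
    F.log_softmax(student_logits / T, dim=1),
    F.softmax(teacher_logits / T, dim=1))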

haitongli commented on July 21, 2024

@xmfbit Indeed, initially I was trying to implement cross entropy with the soft targets directly. However, note that in PyTorch the built-in CrossEntropyLoss only takes (output, target), where the target (i.e., the label) is a class index rather than a one-hot/soft distribution, which is what the KD loss needs. That's why I turned to KL divergence: the two lead to the same optimization result, and KL divergence works naturally with our data representations.

You're welcome to try defining a customized cross-entropy loss function that also leverages PyTorch's optimized C backend (you could define one from scratch in Python, but that might be very slow). If you succeed, please let us know. Thanks!
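For what it's worth, one minimal sketch of such a soft-target cross entropy, built only from log_softmax and tensor ops (the function name soft_cross_entropy and the temperature argument are mine, for illustration):

import torch
import torch.nn.functional as F

def soft_cross_entropy(student_logits, teacher_probs, T=1.0):
    # -sum_i p_i * log q_i per sample, averaged over the batch
    log_q = F.log_softmax(student_logits / T, dim=1)
    return -(teacher_probs * log_q).sum(dim=1).mean()

# quick check with random stand-in tensors
s = torch.randn(4, 10, requires_grad=True)
p = torch.softmax(torch.randn(4, 10), dim=1)
loss = soft_cross_entropy(s, p)
loss.backward()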

nowgood commented on July 21, 2024

Yes, it does not take part in the backward pass, so we can ignore this term.

nowgood commented on July 21, 2024

[Images: derivation showing that the KL divergence between p and q equals the cross entropy H(p, q) minus the entropy H(p)]

H(p, q): cross entropy of p and q
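Spelled out (with p the teacher distribution and q the student distribution), the identity shown in the images is

\mathrm{KL}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i} = \underbrace{-\sum_i p_i \log q_i}_{H(p,\,q)} - \underbrace{\left(-\sum_i p_i \log p_i\right)}_{H(p)} = H(p, q) - H(p)

Since H(p) does not depend on the student's parameters, minimizing KL(p ‖ q) gives the same gradients as minimizing the cross entropy H(p, q).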

xmfbit commented on July 21, 2024

Great, thanks!

@michaelklachko Sorry, but I found some bugs in the code I provided (when checking the gradient with print, grad and grad_check refer to the same tensor instance, so of course they are equal!). I paste a corrected version below.

I also want to point out that the author's implementation of the KD loss, which uses torch.nn.KLDivLoss with the default reduction, scales the gradient down by the number of classes compared with cross entropy. See the PyTorch documentation for details:

size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample.

reduction (string, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘elementwise_mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘elementwise_mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: ‘elementwise_mean’

import torch
import torch.nn as nn
import torch.nn.functional as F
# sample number
N = 10
# category number
C = 5
# softmax output of teacher
p = torch.softmax(torch.rand(N, C), dim=1)
# logit output of student
s = torch.rand(N, C, requires_grad=True)
# softmax output of student, T = 1
q = torch.softmax(s, dim=1)
# KL divergence
# this is the author's implementation:
# torch averages over all N*C elements because that is the default reduction
# kl_loss = nn.KLDivLoss()(torch.log(q), p)
# I think this is the right formulation: sum over classes, then average over the batch
kl_loss = (nn.KLDivLoss(reduction='none')(torch.log(q), p)).sum(dim=1).mean()
kl_loss.backward(retain_graph=True)
print('grad using KLDivLoss')
print(s.grad)
# clear the grad
s.grad.zero_()
# bug 2: should not do an element-wise mean
# ce_loss = torch.mean(-torch.log(q) * p)
ce_loss = torch.mean(torch.sum(-torch.log(q) * p, dim=1))
ce_loss.backward()
print('grad using CE loss')
print(s.grad)
# the real gradient of s should be (q - p) / batch_size
print('real grad, should be (q - p) / batch_size')
print((q - p) / N)
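As an aside, if your PyTorch version supports it, KLDivLoss also has reduction='batchmean', which sums over all elements and divides by the batch size only, so it should equal the manual sum(dim=1).mean() above:

# equivalent formulation, reusing p and q from the snippet above
# (assumes a PyTorch version that provides reduction='batchmean')
kl_loss_batchmean = nn.KLDivLoss(reduction='batchmean')(torch.log(q), p)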

@peterliht Could you check this?

PaTricksStar commented on July 21, 2024

@xmfbit H(p) is constant, right?

haitongli commented on July 21, 2024

It would help me understand your question better if you could mention which file and lines you are referring to.

PaTricksStar commented on July 21, 2024

Oh, it is in model/net.py, in the loss_fn_kd function, line 107.

PaTricksStar commented on July 21, 2024

Thanks for your reply.

xmfbit commented on July 21, 2024

@peterliht Why is KL divergence used to compute the KD loss? The paper "Distilling the Knowledge in a Neural Network" says,

The first objective function is the cross entropy with the soft targets

The KD-loss should be -\sum_{i=1}^C soft_target_i * \log(softmax(student_cls_output_i / T))

michaelklachko commented on July 21, 2024

Has anyone tried cross-entropy? Does it work better or worse than KL?

PaTricksStar commented on July 21, 2024

Has anyone tried cross-entropy? Does it work better or worse than KL?

No. They should lead to the same or similar results, given the discussion above.

xmfbit commented on July 21, 2024

Has anyone tried cross-entropy? Does it work better or worse than KL?

@michaelklachko The gradients with respect to the student's output are the same for KL divergence and the classic KD loss from Hinton's paper. You can refer to the figure given by nowgood. Use this code to check it numerically (p and q here differ from nowgood's figure).

import torch
import torch.nn as nn
import torch.nn.functional as F

N = 10
C = 5
# softmax output by teacher
p = torch.softmax(torch.rand(N, C), dim=1)
# softmax output by student
q = torch.softmax(torch.rand(N, C), dim=1)
#q = torch.ones(N, C)
q.requires_grad = True
# KL divergence
kl_loss = nn.KLDivLoss()(torch.log(q), p)
kl_loss.backward()

grad = q.grad

q.grad.zero_()
ce_loss = torch.mean(torch.log(q) * p)
ce_loss.backward()

grad_check = q.grad
print(grad)
print(grad_check)

michaelklachko commented on July 21, 2024

Great, thanks!

Bo396543018 commented on July 21, 2024

[Quoted: @xmfbit's correction and corrected code from above]
When I tried to use the loss code provided by the author on a new task with 1000 categories, I found the KL loss term was very small, around 1e-6, and in both CIFAR-10 and my task the KL loss never seems to decrease. Can you tell me how to fix this problem? Thank you.

PaTricksStar commented on July 21, 2024

[Quoted: @Bo396543018's question above about the KL loss term being very small (~1e-6) and never decreasing]

Try decreasing your learning rate, or train with the CE loss to check for bugs.

erichhhhho commented on July 21, 2024

@Bo396543018 The same here, did you solve the problem?

sidsingla commented on July 21, 2024

Hi,
Using the implementation from #2 (comment), I am getting large soft-loss values, around 200, but using https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100, my soft-loss value is around 1e-6. It's surprising, because we expect the results to be similar.
Any help would be appreciated. Thanks!

pratikchhapolika commented on July 21, 2024

So what is the answer: why softmax for the teacher output, but log-softmax for the student output?

jl749 commented on July 21, 2024

@pratikchhapolika
https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
It's just how PyTorch's KLDivLoss() takes its arguments:
input = predicted log-softmax (student)
target = softmax (teacher), when log_target == False

As shown above, KL divergence = cross entropy − entropy.
[Image: the KL = cross entropy − entropy identity]
log(p(y)) = student output
q(y) = teacher output
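A minimal sketch of that calling convention with the functional API (the tensor names and shapes are placeholders of mine; log_target requires a reasonably recent PyTorch):

import torch
import torch.nn.functional as F

student_logits = torch.randn(2, 5)   # stand-in values
teacher_logits = torch.randn(2, 5)

log_p = F.log_softmax(student_logits, dim=1)   # input: log-probabilities (student)
q = F.softmax(teacher_logits, dim=1)           # target: probabilities (teacher)
loss = F.kl_div(log_p, q, reduction='batchmean', log_target=False)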
