dougbrion / pytorch-classification-uncertainty

This repo contains a PyTorch implementation of the paper: "Evidential Deep Learning to Quantify Classification Uncertainty"

Home Page: http://arxiv.org/abs/1806.01768

License: MIT License

Python 100.00%
paper pytorch evidential-deep-learning uncertainty-neural-networks classification mnist mnist-classification dirichlet-distributions uncertainty torchvision

pytorch-classification-uncertainty's Issues

Inappropriate loss function

Hello,

In your code, the KL divergence is defined as

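import torch

from helpers import get_device  # assumed: the repo's device helper in helpers.py


def kl_divergence(alpha, num_classes, device=None):
    # KL( Dir(alpha) || Dir(1, ..., 1) ) for each row (sample) of alpha
    if device is None:
        device = get_device()
    ones = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    sum_alpha = torch.sum(alpha, dim=1, keepdim=True)
    # Log-normalizer part: ln Γ(S) − Σ ln Γ(α_k) + Σ ln Γ(1) − ln Γ(K)
    first_term = (
        torch.lgamma(sum_alpha)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        + torch.lgamma(ones).sum(dim=1, keepdim=True)
        - torch.lgamma(ones.sum(dim=1, keepdim=True))
    )
    # Expectation part: Σ (α_k − 1)(ψ(α_k) − ψ(S))
    second_term = (
        (alpha - ones)
        .mul(torch.digamma(alpha) - torch.digamma(sum_alpha))
        .sum(dim=1, keepdim=True)
    )
    kl = first_term + second_term
    return kl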

But isn't the following term in first_term meaningless?
+ torch.lgamma(ones).sum(dim=1, keepdim=True)
Since log gamma([1, 1, ...]) is always 0, it adds nothing.

If you have other ideas, please let me know.

Thank you
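To make the point concrete, here is a dependency-free Python sketch of the same closed form, KL(Dir(α) ‖ Dir(1, …, 1)), for a single sample. It mirrors first_term and second_term from the code above using math.lgamma and a hand-written digamma; the names `digamma` and `kl_to_uniform_dirichlet` and the recurrence-plus-asymptotic-series digamma are our own illustration, not part of the repo. Because lgamma(1) = 0, the flagged sum contributes exactly nothing; it would only matter if the target Dirichlet had parameters other than 1.

```python
import math
from typing import Sequence


def digamma(x: float) -> float:
    """Digamma psi(x) for x > 0 (illustrative helper, not from the repo)."""
    result = 0.0
    # Shift x upward with psi(x) = psi(x + 1) - 1/x until the series is accurate.
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    f = 1.0 / (x * x)
    # Asymptotic expansion: ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - ...
    result += math.log(x) - 0.5 / x - f * (
        1 / 12 - f * (1 / 120 - f * (1 / 252 - f * (1 / 240 - f / 132)))
    )
    return result


def kl_to_uniform_dirichlet(alpha: Sequence[float]) -> float:
    """KL(Dir(alpha) || Dir(1, ..., 1)) for one sample, same terms as the repo code."""
    k = len(alpha)
    s = sum(alpha)
    first_term = (
        math.lgamma(s)
        - sum(math.lgamma(a) for a in alpha)
        + k * math.lgamma(1.0)  # the flagged term: lgamma(1) == 0, so this is 0
        - math.lgamma(float(k))
    )
    second_term = sum((a - 1.0) * (digamma(a) - digamma(s)) for a in alpha)
    return first_term + second_term


print(math.lgamma(1.0))                          # 0.0
print(kl_to_uniform_dirichlet([1.0, 1.0, 1.0]))  # 0.0: identical distributions
print(kl_to_uniform_dirichlet([2.0, 1.0, 1.0]))  # ln 3 - 5/6, about 0.2653
```

Keeping the zero term is harmless; it simply leaves the expression in the general Dirichlet-to-Dirichlet form.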

KL Divergence

Hi, I am very interested in this work.

I find that the kl_divergence loss in the code seems to differ from the one in the paper:
lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)

I can't find the part torch.sum(torch.lgamma(beta), dim=1, keepdim=True) anywhere in the paper. Maybe I have misunderstood it. Can you help me?
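The line in question computes ln B(β), the log of the multivariate Beta function that normalizes the uniform Dirichlet with β = (1, …, 1): ln B(β) = Σ_k ln Γ(β_k) − ln Γ(Σ_k β_k). Since ln Γ(1) = 0, the sum vanishes and ln B(1, …, 1) = −ln Γ(K). As far as we can tell, this is the Γ(K) that appears in the denominator of the paper's closed-form KL, so the two agree; the code just writes the general form. A small check (the function name `log_multivariate_beta` is ours):

```python
import math
from typing import Sequence


def log_multivariate_beta(beta: Sequence[float]) -> float:
    """ln B(beta) = sum_k ln Gamma(beta_k) - ln Gamma(sum_k beta_k)."""
    return sum(math.lgamma(b) for b in beta) - math.lgamma(sum(beta))


K = 10  # number of classes, e.g. MNIST
ln_b_uni = log_multivariate_beta([1.0] * K)
# With beta = 1 the first sum is 0, leaving -ln Gamma(K) = -ln((K - 1)!)
print(ln_b_uni, -math.lgamma(K), -math.log(math.factorial(K - 1)))
```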

Fail to train in mini-Imagenet

I used the EDL loss to train on the mini-ImageNet dataset with 64 classes, but the loss does not converge and the accuracy is very low.

Proof for loss function

Hi, thanks for your excellent work. I can't follow the derivation of the loss function, and its proof is provided in the supplementary material. Where can I download the supplementary material?
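As a sketch of the key step, assuming the loss in question is the squared-error form of the EDL loss: for p ~ Dir(α) with S = Σ_k α_k and p̂_k = α_k / S, we have E[(y_k − p_k)²] = (y_k − p̂_k)² + Var(p_k), and the Dirichlet variance is Var(p_k) = α_k(S − α_k) / (S²(S + 1)) = p̂_k(1 − p̂_k) / (S + 1). The code below (function names and the example α are ours) compares that closed form with a Monte Carlo estimate that draws Dirichlet samples as normalized Gamma variates.

```python
import random
from typing import Sequence


def expected_squared_error(alpha: Sequence[float], y: Sequence[float]) -> float:
    """Closed form of E_{p ~ Dir(alpha)}[ sum_k (y_k - p_k)^2 ]."""
    s = sum(alpha)
    p_hat = [a / s for a in alpha]
    err = sum((yk - pk) ** 2 for yk, pk in zip(y, p_hat))   # squared-bias term
    var = sum(pk * (1.0 - pk) / (s + 1.0) for pk in p_hat)  # Dirichlet variance term
    return err + var


def monte_carlo_squared_error(alpha, y, n_samples=50_000, seed=0):
    """Estimate the same expectation by sampling p ~ Dir(alpha)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        # A Dirichlet draw is a vector of independent Gamma(alpha_k, 1) draws, normalized.
        g = [rng.gammavariate(a, 1.0) for a in alpha]
        g_sum = sum(g)
        total += sum((yk - gk / g_sum) ** 2 for yk, gk in zip(y, g))
    return total / n_samples


alpha = [3.0, 1.0, 1.0]  # hypothetical evidence + 1 for a 3-class sample
y = [1.0, 0.0, 0.0]      # one-hot target
print(expected_squared_error(alpha, y))     # 0.24 + 0.56/6, about 0.3333
print(monte_carlo_squared_error(alpha, y))  # should land close to the closed form
```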

how to get uncertainty?

Thanks for your work; this code is very helpful to me. I have a question about how to calculate the uncertainty. In the paper, I found the formula shown in this picture:
[Image: WeChat screenshot of the paper's formula $u + \sum_{k=1}^{K} b_k = 1$]
If I understand correctly, u is the uncertainty and b_k is the model's output for a single class; in other words, b_k is the classification probability. So u can be obtained as 1 minus the sum of b_k over all classes.
But in the section "Comparing in sample and out of sample classification", I don't understand why the uncertainty for Master Yoda is 1.0, because the classification probabilities of the other classes do not look like 0.0.
Looking forward to your answer. Thanks again.
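In the paper's subjective-logic formulation, with evidence e_k ≥ 0, α_k = e_k + 1 and S = Σ_k α_k, the belief masses are b_k = e_k / S and the uncertainty is u = K / S, which is what makes u + Σ_k b_k = 1. The b_k are therefore not the class probabilities; the expected probability is p̂_k = α_k / S. The plain-Python sketch below (the function name is ours) shows that when every evidence value is 0, u = 1 while each p̂_k = 1/K, which would be consistent with an out-of-distribution image getting uncertainty 1.0 while its plotted probabilities are nonzero.

```python
from typing import List, Sequence, Tuple


def subjective_opinion(evidence: Sequence[float]) -> Tuple[List[float], float, List[float]]:
    """Return (beliefs b_k, uncertainty u, expected probabilities p_hat_k)."""
    num_classes = len(evidence)
    alpha = [e + 1.0 for e in evidence]  # Dirichlet parameters
    s = sum(alpha)                       # Dirichlet strength S
    beliefs = [e / s for e in evidence]  # b_k = e_k / S
    uncertainty = num_classes / s        # u = K / S
    probs = [a / s for a in alpha]       # p_hat_k = alpha_k / S
    return beliefs, uncertainty, probs


# No evidence for any class (e.g. the evidence activation outputs 0 everywhere):
b, u, p = subjective_opinion([0.0] * 10)
print(u, p[0])  # 1.0 0.1  -> maximal uncertainty, uniform probabilities

# Strong evidence for class 0: u shrinks, and u + sum(b) is still 1.
b, u, p = subjective_opinion([40.0] + [0.0] * 9)
print(round(u, 4), round(u + sum(b), 10))  # 0.2 1.0
```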

Function one_hot_embedding may be missing `.to(device)`

Hi guys, nice work! However, did you perhaps forget to move torch.eye to the device in one_hot_embedding?

I modified one_hot_embedding in helpers.py as follows:

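import torch


def one_hot_embedding(labels, num_classes=10):
    # Convert integer labels to one-hot vectors on the active device.
    # get_device() is defined alongside this function in helpers.py.
    device = get_device()
    y = torch.eye(num_classes).to(device)  # row i of the identity is the one-hot vector for class i
    return y[labels]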

About annealing_step

Thank you very much for the nice work.
I want to train ResNet20 on the CIFAR10 dataset, and the model needs to be trained for 300 epochs. Should I change annealing_step to 300, or leave it unchanged at 10?
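For reference, the paper phases the KL regularizer in with the coefficient λ_t = min(1.0, t/10), where t is the current epoch. Assuming annealing_step plays the role of that 10 (i.e. a weight of min(1, epoch / annealing_step)), the sketch below compares the two options: with annealing_step = 10 the KL term reaches full weight after 10 of the 300 epochs, while with annealing_step = 300 it only reaches full weight at the very end of training. The function name is ours, not the repo's.

```python
def annealing_coef(epoch: int, annealing_step: int) -> float:
    """KL weight lambda_t = min(1, epoch / annealing_step) (assumed schedule)."""
    return min(1.0, epoch / annealing_step)


for step in (10, 300):
    schedule = [round(annealing_coef(epoch, step), 3) for epoch in (0, 5, 10, 150, 300)]
    print(step, schedule)
# 10 [0.0, 0.5, 1.0, 1.0, 1.0]
# 300 [0.0, 0.017, 0.033, 0.5, 1.0]
```

Which schedule works better for a 300-epoch run is an empirical question; the code only shows how quickly each setting brings in the KL penalty.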
