
polyloss-pytorch's Introduction

polyloss-pytorch

class PolyLoss(softmax=False, ce_weight=None, reduction='mean', epsilon=1.0)

This class is used to compute the Poly-1 Loss between the input and target tensors.

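Poly-1 Loss is defined as $L_{\text{Poly-1}} = -\log(P_t) + \epsilon_1 (1 - P_t)$, where $P_t$ is the predicted probability of the target class.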

The prediction input is compared with the ground-truth target. input is expected to have shape BNHW[D], where N is the number of classes. It can contain either logits or probabilities for each class; if passing logits as input, set softmax=True. target is expected to have shape B1HW[D], BHW[D], or BNHW[D] (one-hot format).
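The three accepted target layouts can be told apart by comparing the target's shape with the prediction's shape. The sketch below illustrates that idea with a hypothetical helper, to_class_indices, which is not part of this repository; it reduces any of the three layouts to plain class indices and assumes the one-hot layout has exactly the same shape as the prediction.

```python
import torch
import torch.nn.functional as F


def to_class_indices(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not from this repo): map a target in any of the
    layouts B1HW[D], BHW[D] or BNHW[D] to class indices of shape BHW[D]."""
    if target.dim() == pred.dim() - 1:
        # BHW[D]: already class indices
        return target.long()
    if target.dim() == pred.dim() and target.shape[1] == 1:
        # B1HW[D]: drop the singleton channel axis
        return target[:, 0].long()
    if target.shape == pred.shape:
        # BNHW[D]: one-hot, so the class is the position of the 1
        return target.float().argmax(dim=1)
    raise ValueError(f"unsupported target shape {tuple(target.shape)}")


pred = torch.rand(2, 5, 3, 3)                      # B=2, N=5, H=W=3
idx = torch.randint(0, 5, (2, 3, 3))               # BHW class indices
one_hot = F.one_hot(idx, num_classes=5).permute(0, 3, 1, 2)  # BNHW

# All three layouts describe the same labels.
same = [to_class_indices(t, pred) for t in (idx, idx[:, None], one_hot)]
```

Note that when N is 1 the B1HW and one-hot layouts are indistinguishable by shape alone; the sketch treats that case as B1HW.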

epsilon is the first polynomial coefficient in the cross-entropy loss. For best results, this value needs to be adjusted for each task and dataset; the optimal value can be found through hyperparameter tuning.
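To make the role of epsilon concrete, here is a minimal standalone sketch of the Poly-1 idea from the paper: ordinary cross-entropy plus epsilon times (1 - p_t), where p_t is the predicted probability of the target class. The function name poly1_cross_entropy is made up for this illustration, the sketch takes logits and class-index targets, and it ignores class weights; it is not the PolyLoss class from this repository.

```python
import torch
import torch.nn.functional as F


def poly1_cross_entropy(logits, target, epsilon=1.0):
    """Sketch of Poly-1 loss (hypothetical name, no class weights).

    logits: (B, N, ...) raw scores; target: (B, ...) class indices.
    """
    # Per-element cross-entropy, i.e. -log(p_t)
    ce = F.cross_entropy(logits, target, reduction="none")
    # Recover p_t from the cross-entropy (valid because no weights are used)
    pt = torch.exp(-ce)
    # Add the first polynomial term, then average over all elements
    return (ce + epsilon * (1.0 - pt)).mean()


torch.manual_seed(0)
logits = torch.randn(2, 5, 3, 3, requires_grad=True)
target = torch.randint(0, 5, (2, 3, 3))

loss = poly1_cross_entropy(logits, target, epsilon=1.0)
loss.backward()
```

With epsilon=0 the sketch reduces to plain cross-entropy, which is a convenient sanity check when tuning.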

The original paper: Leng, Z. et al. (2022): PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions.

Parameters

  • softmax (bool) – if True, apply a softmax function to the prediction (i.e. input).
  • ce_weight (Tensor, optional) – a manual rescaling weight given to each class. If given, it has to be a Tensor of size N (same as the weight argument of the nn.CrossEntropyLoss class).
  • reduction (string, optional) – specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied; 'mean': the weighted mean of the output is taken; 'sum': the output will be summed.
  • epsilon (float, optional) – the first polynomial coefficient. Defaults to 1.0.

A Colab tutorial: Use PolyLoss with Fast.ai and Weights & Biases

In the Colab tutorial, I provide an example of how to use PolyLoss in fastai (super easy!) and run a hyperparameter search with Weights & Biases.

How to Use

Examples

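import torch
from PolyLoss import to_one_hot, PolyLoss

# Example of target in one-hot encoded format
loss = PolyLoss(softmax=True)  # softmax=True: the input is treated as logits
B, C, H, W = 2, 5, 3, 3
input = torch.rand(B, C, H, W, requires_grad=True)
# torch.randint excludes `high`, so high=C samples class indices 0..C-1
target = torch.randint(low=0, high=C, size=(B, H, W)).long()
# Add a channel axis, then convert the class indices to one-hot format
target = to_one_hot(target[:, None, ...], num_classes=C)
output = loss(input, target)
output.backward()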



# Example of target not in one-hot encoded format
loss = PolyLoss(softmax=True)  # softmax=True: the input is treated as logits
B, C, H, W = 2, 5, 3, 3
input = torch.rand(B, C, H, W, requires_grad=True)
# Class indices with a singleton channel axis (B1HW); `high` is exclusive
target = torch.randint(low=0, high=C, size=(B, 1, H, W)).long()
output = loss(input, target)
output.backward()


# Example of PolyBCELoss
from PolyLoss import PolyBCELoss
loss = PolyBCELoss()
B, H, W = 2, 3, 3
input = torch.rand(B, H, W, requires_grad=True)
# Random binary labels (0.0 or 1.0) with the same shape as the input
target = torch.empty(B, H, W).random_(2)
output = loss(input, target)
output.backward()

polyloss-pytorch's People

Contributors

glenn-jocher, yiyixuxu


polyloss-pytorch's Issues

Error scatter pytorch

Hello,

While one-hot encoding the labels, I got errors like the following from PyTorch's scatter:

/opt/conda/conda-bld/pytorch_1646756402876/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:276: operator(): block: [66,0,0], thread: [5,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.

Note: my labels are a mask with classes=2 and ignore_index=255 (label.unique() = [0, 1, 255]).
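A likely cause, given the reported labels, is that the value 255 is used as a scatter index into a tensor that only has 2 class channels, which is out of bounds. One possible workaround, sketched below with a hypothetical helper one_hot_with_ignore (not part of this repository), is to replace ignored pixels with an in-range placeholder before scattering and then zero out their one-hot rows.

```python
import torch


def one_hot_with_ignore(labels, num_classes, ignore_index=255):
    """Hypothetical helper (not from this repo).

    labels: (B, H, W) int64 class indices, possibly containing ignore_index.
    Returns a float one-hot tensor (B, N, H, W) and a bool mask (B, H, W)
    that is True where the label is valid.
    """
    valid = labels != ignore_index
    # Replace ignored entries with class 0 so every scatter index is in range
    safe = torch.where(valid, labels, torch.zeros_like(labels))
    one_hot = torch.zeros(
        labels.shape[0], num_classes, *labels.shape[1:],
        dtype=torch.float32, device=labels.device,
    )
    one_hot.scatter_(1, safe.unsqueeze(1), 1.0)
    # Ignored pixels get an all-zero class vector
    one_hot = one_hot * valid.unsqueeze(1)
    return one_hot, valid


labels = torch.tensor([[[0, 1], [255, 1]]])  # one ignored pixel
one_hot, valid = one_hot_with_ignore(labels, num_classes=2)
```

The returned mask can then be used to exclude ignored pixels when averaging a per-pixel loss.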

NameError: name 'target_idx' is not defined

README example is not reproducible as target_idx does not exist.

import torch
from PolyLoss import to_one_hot, PolyLoss
# Example of target in one-hot encoded format
loss = PolyLoss(softmax=True)
B, C, H, W = 2, 5, 3, 3
input = torch.rand(B, C, H, W, requires_grad=True)
target = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
target = to_one_hot(target_idx[:, None, ...], num_classes=C)
output = loss(input, target)
output.backward()

...

Traceback (most recent call last):
  File "/Users/glennjocher/PycharmProjects/yolov5-pro/venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-0534d4642825>", line 8, in <module>
    target = to_one_hot(target_idx[:, None, ...], num_classes=C)
NameError: name 'target_idx' is not defined

Poly Loss image segmentation

Hello, thank you for your work. I want to use the poly loss for image segmentation. My input has shape [B,C,H,W] and my mask has shape [B,H,W]. The masks also contain ignore_index values. How can I use your loss function correctly? Thank you.
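The page does not say whether PolyLoss itself supports ignore_index, so the following is only a standalone sketch of one way to handle this setup in plain PyTorch. It assumes logits of shape [B,C,H,W], an int64 mask of shape [B,H,W], and an ignore value of 255; the function name poly1_ce_ignore is made up for this illustration.

```python
import torch
import torch.nn.functional as F


def poly1_ce_ignore(logits, mask, epsilon=1.0, ignore_index=255):
    """Sketch of a Poly-1 loss that skips ignored pixels (hypothetical name).

    logits: (B, C, H, W) raw scores; mask: (B, H, W) int64 labels.
    """
    # Per-pixel cross-entropy; ignored pixels contribute 0 here
    ce = F.cross_entropy(
        logits, mask, ignore_index=ignore_index, reduction="none"
    )
    valid = mask != ignore_index
    pt = torch.exp(-ce)                   # target-class probability
    poly1 = ce + epsilon * (1.0 - pt)     # Poly-1 term per pixel
    # Average over valid pixels only; clamp avoids division by zero
    return poly1[valid].sum() / valid.sum().clamp(min=1)


torch.manual_seed(0)
logits = torch.randn(2, 2, 4, 4, requires_grad=True)
mask = torch.randint(0, 2, (2, 4, 4))
mask[:, 0, :] = 255                       # mark the first row as ignored

loss = poly1_ce_ignore(logits, mask, epsilon=1.0)
loss.backward()
```

Because the mean is taken over valid pixels only, gradients at ignored positions are zero and the result with epsilon=0 matches F.cross_entropy with the same ignore_index.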

Passing probability into CrossEntropy

torch.nn.CrossEntropyLoss expects raw, unnormalized logits. However, either logits or probabilities may be passed into the cross-entropy loss here. Passing probabilities will lead to unintended results.
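The effect described in this issue can be reproduced in plain PyTorch, independent of this repository. The sketch below compares cross-entropy computed from logits, cross-entropy computed (incorrectly) from probabilities, and a negative log-likelihood computed from the log of the probabilities, which is one way to handle inputs that are already normalized.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[4.0, 0.0, 0.0],
                       [0.0, 0.0, 4.0]])
target = torch.tensor([0, 2])
probs = F.softmax(logits, dim=1)

# Intended usage: cross_entropy applies log-softmax to raw logits
ce_from_logits = F.cross_entropy(logits, target)

# Problematic usage: softmax is effectively applied a second time,
# which flattens the distribution and inflates the loss
ce_from_probs_wrong = F.cross_entropy(probs, target)

# If only probabilities are available, take the log and use nll_loss
# (the clamp guards against log(0))
ce_from_probs_right = F.nll_loss(torch.log(probs.clamp_min(1e-12)), target)
```

In this example the model is confident and correct, so the loss from logits is small, while the double-softmax version is noticeably larger.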
