class-balanced-loss's Introduction

Class-Balanced Loss Based on Effective Number of Samples

TensorFlow code for the paper:

Class-Balanced Loss Based on Effective Number of Samples
Yin Cui, Menglin Jia, Tsung-Yi Lin, Yang Song, Serge Belongie

Dependencies:

  • Python (3.6)
  • TensorFlow (1.14)

Datasets:

  • Long-Tailed CIFAR. We provide a download link that includes all the data used in our paper in .tfrecords format. The data was converted and generated by src/generate_cifar_tfrecords.py (original CIFAR) and src/generate_cifar_tfrecords_im.py (long-tailed CIFAR).

Effective Number of Samples:

For a visualization of the data and effective number of samples, please take a look at data.ipynb.
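
For reference, the paper defines the effective number of samples of a class with n samples as E_n = (1 - β^n) / (1 - β), where β = (N - 1) / N, and the class-balanced loss weights the loss of class y by the inverse, (1 - β) / (1 - β^{n_y}).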

Key Implementation Details:
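
A minimal sketch of the class-balanced weighting (an illustration of the paper's formula, not the repo's exact code; the per-class sample counts below are hypothetical):

import numpy as np

def class_balanced_weights(samples_per_cls, beta):
    # Effective number of samples per class: E_n = (1 - beta^n) / (1 - beta).
    effective_num = (1.0 - np.power(beta, samples_per_cls)) / (1.0 - beta)
    # Weight each class by 1 / E_n, rescaled to sum to the number of classes.
    weights = 1.0 / effective_num
    return weights / weights.sum() * len(samples_per_cls)

# Hypothetical long-tailed class counts:
print(class_balanced_weights(np.array([5000, 1000, 200, 40, 8]), beta=0.9999))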

Training and Evaluation:

We provide 3 .sh scripts for training and evaluation.

  • On original CIFAR dataset:
./cifar_trainval.sh
  • On long-tailed CIFAR dataset (the hyperparameter IM_FACTOR is the inverse of "Imbalance Factor" in the paper):
./cifar_im_trainval.sh
  • On long-tailed CIFAR dataset using the proposed class-balanced loss (set non-zero BETA):
./cifar_im_trainval_cb.sh
  • Run Tensorboard for visualization:
tensorboard --logdir=./results --port=6006
  • The figure below shows the results of running ./cifar_im_trainval.sh and ./cifar_im_trainval_cb.sh.

Training with TPU:

We train networks on iNaturalist and ImageNet datasets using Google's Cloud TPU. The code for this section is in tpu/. Our code is based on the official implementation of Training ResNet on Cloud TPU and forked from https://github.com/tensorflow/tpu.

Data Preparation:

  • Download the datasets (except images) from this link and unzip them under tpu/. The unzipped directory tpu/raw_data/ contains the training and validation splits. For the raw images, please download them from the following links and put them into the corresponding folders in tpu/raw_data/:

  • Convert datasets into .tfrecords format and upload to Google Cloud Storage (gcs) using tpu/tools/datasets/dataset_to_gcs.py:

python dataset_to_gcs.py \
  --project=$PROJECT \
  --gcs_output_path=$GCS_DATA_DIR \
  --local_scratch_dir=$LOCAL_TFRECORD_DIR \
  --raw_data_dir=$LOCAL_RAWDATA_DIR

The following 3 .sh scripts in tpu/ can be used to train and evaluate models on iNaturalist and ImageNet using Cloud TPU. For more details on how to use Cloud TPU, please refer to Training ResNet on Cloud TPU.

Note that the image mean, standard deviation, and input size need to be updated accordingly for each dataset.

  • On ImageNet (ILSVRC 2012):
./run_ILSVRC2012.sh
  • On iNaturalist 2017:
./run_inat2017.sh
  • On iNaturalist 2018:
./run_inat2018.sh
  • The pre-trained models, including all logs viewable on tensorboard, can be downloaded from the following links:
Dataset            Network     Loss                        Input Size   Download Link
ILSVRC 2012        ResNet-50   Class-Balanced Focal Loss   224          link
iNaturalist 2018   ResNet-50   Class-Balanced Focal Loss   224          link

Citation

If you find our work helpful in your research, please cite it as:

@inproceedings{cui2019classbalancedloss,
  title={Class-Balanced Loss Based on Effective Number of Samples},
  author={Cui, Yin and Jia, Menglin and Lin, Tsung-Yi and Song, Yang and Belongie, Serge},
  booktitle={CVPR},
  year={2019}
}

class-balanced-loss's People

Contributors

richardaecn


class-balanced-loss's Issues

Training on a new dataset...

Hi Yin

Thanks for sharing the code! I wanted to run your code on my own dataset (or some other new dataset) and get the corresponding evaluation metrics. Can you give me some insights on how to use your code on new datasets from scratch?

Thanks

Performance issues in tpu/models/ (by P3)

Hello! I've found a performance issue in tpu/models/: batch() should be called before map(), which could make your program more efficient. Here is the TensorFlow documentation that supports this.

A detailed description is listed below:

  • official/retinanet/dataloader.py: dataset.batch(batch_size, drop_remainder=True) (here) should be called before dataset.map(_dataset_parser, num_parallel_calls=64) (here).
  • experimental/inception/inception_v3_old.py: .batch(batch_size) (here) should be called before .map(parser) (here).

Besides, you need to check whether the function passed to map() (e.g., parser in .map(parser)) is affected by the reordering, to make the changed code work properly. For example, if parser needs data with shape (x, y, z) as its input before the fix, it would require data with shape (batch_size, x, y, z) afterward.
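
For illustration, a minimal tf.data sketch of the suggested reordering (my own toy example, not the repo's code; it only works when the mapped function can handle a whole batch at once):

import tensorflow as tf

def parser(record):
    return record * 2  # toy element-wise transform, stand-in for _dataset_parser

dataset = tf.data.Dataset.range(1000)

# Current ordering: per-element map, then batch.
slow = dataset.map(parser, num_parallel_calls=64).batch(128, drop_remainder=True)

# Suggested ordering: batch first, then one (vectorized) map call per batch.
fast = dataset.batch(128, drop_remainder=True).map(parser, num_parallel_calls=64)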

Looking forward to your reply. By the way, I would be glad to create a PR to fix this if you are too busy.

train CIFAR failed

I used ./cifar_trainval.sh to train, but it runs into the problems below. Why?

Traceback (most recent call last):
File "/environment/python/versions/miniconda3-4.7.12/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
return fn(*args)
File "/environment/python/versions/miniconda3-4.7.12/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1350, in _run_fn
target_list, run_metadata)
File "/environment/python/versions/miniconda3-4.7.12/envs/tf1.15/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1443, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
(0) Internal: Blas GEMM launch failed : a.shape=(128, 64), b.shape=(64, 10), m=128, n=10, k=64
[[{{node resnet/tower_0/fully_connected/dense/MatMul}}]]
[[resnet/tower_0/softmax_cross_entropy_loss/assert_broadcastable/is_valid_shape/has_valid_nonscalar_shape/has_invalid_dims/concat/_1549]]
(1) Internal: Blas GEMM launch failed : a.shape=(128, 64), b.shape=(64, 10), m=128, n=10, k=64
[[{{node resnet/tower_0/fully_connected/dense/MatMul}}]]
0 successful operations.
0 derived errors ignored.

CB-Loss makes no sense when n_y is large

When n_y is large, it seems that the loss weights α are always equal to 1. If so, the CB loss makes no sense. @richardaecn
For example, suppose 10,000 images belong to class A and only 1,000 images belong to class B. Then the CB weights are [1, 1], no matter what β is.
Please correct me if I have any misunderstanding. Thanks.
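
A quick numerical check of this example (my own illustration, not the authors' reply) suggests the collapse happens only when β is not close enough to 1 relative to the class counts:

import numpy as np

n = np.array([10_000, 1_000])  # class A and class B from the example above
for beta in (0.99, 0.999, 0.9999):
    w = (1.0 - beta) / (1.0 - np.power(beta, n))  # weights proportional to 1 / E_n
    w = w / w.sum() * len(n)                      # normalize to sum to #classes
    print(beta, w.round(3))
# 0.99   -> [1.0, 1.0]      (both beta**n are ~0, so the weights collapse)
# 0.999  -> [0.775, 1.225]
# 0.9999 -> [0.262, 1.738]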

Performance issues in tpu/models/official/retinanet/dataloader.py

Hello, I found a performance issue in the definition of __call__ in tpu/models/official/retinanet/dataloader.py: dataset = dataset.map(_process_example) is called without num_parallel_calls. I think it will increase the efficiency of your program if you add this argument.

The same issue also exists in dataset = dataset.map(parser) and dataset = dataset.repeat().map(parser).

Here is the documentation of TensorFlow that supports this.
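
A minimal sketch of the suggested change (a toy example of mine, not the repo's code):

import tensorflow as tf

def _process_example(x):
    return x + 1  # stand-in for the real parsing function

dataset = tf.data.Dataset.range(100)

# Sequential map (current):
sequential = dataset.map(_process_example)

# Parallel map (suggested); AUTOTUNE lets tf.data choose the parallelism level.
parallel = dataset.map(_process_example,
                       num_parallel_calls=tf.data.experimental.AUTOTUNE)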

Looking forward to your reply. By the way, I would be glad to create a PR to fix this if you are too busy.

About data augmentation

Hi, may I ask a question about the description in the paper?

As you mentioned in the paper, in the last paragraph of Section 3.1:

"the stronger the data augmentation is, the smaller the N will be. The small neighboring region of a sample is a way to capture all near-duplicates and instances that can be obtained by data augmentation"

Shouldn't a stronger data augmentation technique provide more samples of the same class (S), making their volume (N) larger?

multi label problems

How do I set the weights?
If image1 has classes 1, 2, and 3, is its weight (class1_num + class2_num + class3_num)?

focal loss modulator

How do you derive the modulator?

The code in your repo:

modulator = tf.exp(-gamma * labels * logits - gamma * tf.log1p(
          tf.exp(-1.0 * logits)))

For the focal loss in tensorflow/models/blob/master/research/object_detection, the form is the same as what is shown in the paper:

    prediction_probabilities = tf.sigmoid(prediction_tensor)
    p_t = ((target_tensor * prediction_probabilities) +
           ((1 - target_tensor) * (1 - prediction_probabilities)))
    modulating_factor = 1.0
    if self._gamma:
      modulating_factor = tf.pow(1.0 - p_t, self._gamma)

Could you please tell me how to transform the form in the paper into your form?
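
For what it's worth, a numerical check (my own, assuming sigmoid probabilities and labels in {0, 1}) suggests the two forms are algebraically equivalent:

import numpy as np

gamma = 2.0
x = np.linspace(-5.0, 5.0, 11)             # logits
y = np.array([0.0, 1.0])[:, None]          # labels, broadcast against logits

p = 1.0 / (1.0 + np.exp(-x))               # sigmoid probability
p_t = y * p + (1.0 - y) * (1.0 - p)        # probability of the true class
paper_form = (1.0 - p_t) ** gamma
repo_form = np.exp(-gamma * y * x - gamma * np.log1p(np.exp(-x)))

print(np.allclose(paper_form, repo_form))  # True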

Thank you very much!

Nan values in focal loss

Hi, I am using your implementation of focal loss, and sometimes the computed value is NaN. I realized that this is due to the normalization by the number of positive samples in the batch: I am working with 3D data, so I can't use big batches, and I have a very imbalanced dataset as well. That causes some of my batches to contain only negative samples, so the normalization ends up dividing by zero. How would you recommend I perform the normalization in this case?
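
One possible guard, as a sketch (my assumption, not the authors' recommendation): clamp the normalizer so a batch with no positives divides by 1 instead of 0.

import tensorflow as tf

per_example_loss = tf.constant([0.3, 0.1, 0.4])  # toy focal-loss values
positive_mask = tf.constant([0.0, 0.0, 0.0])     # a batch with no positives

num_positives = tf.reduce_sum(positive_mask)
normalizer = tf.maximum(num_positives, 1.0)      # avoid division by zero
normalized_loss = tf.reduce_sum(per_example_loss) / normalizer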

My implementation in Pytorch doesn't work

Dear authors, thanks for your great effort in making your code open-source.
I re-implemented your CB-focal loss in PyTorch (both your TF version and my own version), but I can't reach the performance reported in your paper.

This is my code; could you please check whether there is something wrong with it?
output: [batch_size, num_classes]
label: [batch_size]
catList: [num_classes], a list of sample counts for each class

[screenshot of the PyTorch implementation]

Question About balanced Loss

Hi,

thanks for sharing the code and for your great work. I have a question about the loss in your paper: why do you add +1 in formula (1), i.e., the term (1 - p)(E_{n-1} + 1)? Where does it come from? Is it the initial expected value?

Thanks
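
(If I read the paper correctly, formula (1) is an expectation over two cases: with probability p the n-th sample overlaps previously sampled data and the effective number stays E_{n-1}; with probability 1 - p it covers a new point, so the effective number grows to E_{n-1} + 1, which is where the +1 comes from. With p = E_{n-1}/N this gives E_n = p * E_{n-1} + (1 - p) * (E_{n-1} + 1) = 1 + β * E_{n-1}, whose closed form is E_n = (1 - β^n)/(1 - β).)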

baseline for long-tail cifar10 is 77.47%

[screenshot of training results]

My implementation of CB-focal in PyTorch can only reach 77% accuracy on long-tailed CIFAR-10, so I ran your source code without any change. But the baseline is 77.5% (my PyTorch baseline is actually 75%), which is almost 2.7% higher than the results reported in your paper.
Does this mean CB-focal is only about 2% higher than the baseline?

En computed in batch or whole dataset?

I tried to compute E_n as a weight within each batch, but the loss quickly becomes NaN. I also tried computing it over the whole dataset with img_per_class = [900000, 700000, 60000]; because the image counts are very large, β^n is almost 0, so the resulting weights are all about 1e-4. I think such weights cannot handle the imbalanced dataset. @richardaecn
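
A quick check of these numbers (my own illustration, not the authors' reply): the raw weights are indeed tiny for huge n, but their relative scale after normalization is what matters, and β has to be pushed closer to 1 to match counts of this size:

import numpy as np

img_per_class = np.array([900_000, 700_000, 60_000])
for beta in (0.9999, 0.999999):
    w = (1.0 - beta) / (1.0 - np.power(beta, img_per_class))
    w = w / w.sum() * len(img_per_class)  # normalize to sum to #classes
    print(beta, w.round(2))
# 0.9999   -> [1.0, 1.0, 1.0]     (nearly uniform: beta too far from 1)
# 0.999999 -> [0.24, 0.29, 2.47]  (rare class now upweighted)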

inference using the models

Respected authors,
Firstly, thank you for releasing the code. Is code available for running inference with the given pre-trained models? It would be really helpful if you could provide it along with your current repository.
Thanks in advance.

mini-batch E_{n} doesn't work

In each batch, I first computed an independent E_{n} for that batch, but it doesn't work at all. According to your code, however, E_{n} is global: there is only one set of values for the entire mini-batch gradient-descent optimization, and that evidently works. Could you explain why?

Class balanced loss code

Hi!
First of all, thank you very much for your code, it's great work!

I would like to know if you could tell me where the class-balanced loss is implemented among all the files that you include.

I've been looking for it and I'm a little lost.

Thank you very much in advance!
