

Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data

Harsh Rangwani*, Sumukh K Aithal*, Mayank Mishra, R. Venkatesh Babu

This is the official PyTorch implementation for our NeurIPS'22 paper: Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data [OpenReview] [arXiv]


UPDATE: We integrated our method with GLMC (CVPR 2023), the current state of the art (SotA). Our method yields a gain of ~2% over GLMC 😄 [link].

Abstract

Figure: Long Tail Saddle Points

Real-world datasets exhibit imbalances of varying types and degrees. Techniques based on re-weighting and margin adjustment of the loss are often used to enhance the performance of neural networks, particularly on minority classes. In this work, we analyze the class-imbalanced learning problem by examining the loss landscape of neural networks trained with re-weighting and margin-based techniques. Specifically, we examine the spectral density of the Hessian of the class-wise loss, through which we observe that the network weights converge to a saddle point in the loss landscapes of minority classes. Following this observation, we also find that optimization methods designed to escape from saddle points can be effectively used to improve generalization on minority classes. We further demonstrate, theoretically and empirically, that Sharpness-Aware Minimization (SAM), a recent technique that encourages convergence to flat minima, can be effectively used to escape saddle points for minority classes. Using SAM results in a 6.2% increase in accuracy on the minority classes over the state-of-the-art Vector Scaling Loss, leading to an overall average increase of 4% across imbalanced datasets.

TLDR: On imbalanced datasets, training converges to a saddle point in the loss landscape of the tail classes, and SAM can effectively escape from these solutions.
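For readers new to SAM, each SAM step uses two gradient evaluations: an ascent step to a perturbed point inside an L2 ball of radius $\rho$, then a descent step that updates the original weights with the gradient taken at that perturbed point. The sketch below is a minimal, framework-free illustration of that standard update on plain Python lists. It is not the repository's code (which uses the PyTorch SAM implementation from https://github.com/davda54/sam), and the helper name `sam_step` is hypothetical.

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_step(w: Vector, grad_fn: Callable[[Vector], Vector], lr: float, rho: float) -> Vector:
    """One sharpness-aware minimization (SAM) step on parameters `w` (hypothetical helper)."""
    g = grad_fn(w)
    g_norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero at a stationary point
    # Ascent step: move to the approximate worst-case point inside an L2 ball of radius rho.
    w_adv = [wi + rho * gi / g_norm for wi, gi in zip(w, g)]
    # Descent step: update the ORIGINAL weights with the gradient taken at the perturbed point.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


if __name__ == "__main__":
    # Toy objective f(w) = 0.5 * w**2 with gradient w.
    print(sam_step([1.0], lambda w: [w[0]], lr=0.1, rho=0.05))
```

With `rho = 0` the perturbed point equals `w`, so the update reduces to plain gradient descent; `rho` is the quantity set by the `--rho` flag in the commands in this README.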

Getting started

  • Requirements

    • pytorch 1.9.1
    • torchvision 0.10.1
    • wandb 0.12.2
    • timm 0.5.5
    • prettytable 2.2.0
    • scikit-learn
    • matplotlib
    • tensorboardX
  • Installation

git clone https://github.com/val-iisc/Saddle-LongTail.git
cd Saddle-LongTail
pip install -r requirements.txt

We use Weights and Biases (wandb) to track our experiments and results. To track your own experiments with wandb, create a new project under your account and change the project and entity arguments of wandb.init in the corresponding .py file for each experiment. To disable wandb tracking, omit the --log_results flag.

  • Datasets

    The datasets used in this repository can be downloaded by following the instructions at the following links:

    The CIFAR datasets are automatically downloaded to the data/ folder if they are not already available.
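The wandb tracking described under Installation is switched on by the --log_results flag. A minimal sketch of how such a switch can gate wandb.init, assuming the scripts parse --log_results with argparse as a boolean `store_true` flag; `build_parser`, `maybe_init_wandb` and the placeholder project/entity names are hypothetical. `wandb.init` does accept `project`, `entity` and `config` keyword arguments.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical subset of the training flags; only --log_results matters here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_results", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    return parser


def maybe_init_wandb(args, project: str = "your-project", entity: str = "your-entity"):
    """Start a wandb run only when --log_results is passed (sketch, not the repository code)."""
    if not args.log_results:
        return None
    import wandb  # imported lazily so runs without --log_results do not need wandb installed
    return wandb.init(project=project, entity=entity, config=vars(args))


if __name__ == "__main__":
    args = build_parser().parse_args(["--seed", "0"])
    print(maybe_init_wandb(args))  # --log_results absent, so no run is started
```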

Training

Sample command to train on the CIFAR-10 LT dataset with LDAM+DRW+SAM.

python cifar_train_sam.py --gpu 0 --imb_type exp --imb_factor 0.01 --loss_type LDAM --train_rule DRW --rho 0.8 --rho_schedule none --log_results --dataset cifar10 --seed 0
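The `--train_rule DRW` option refers to deferred re-weighting. The sketch below assumes the repository follows the convention of the LDAM-DRW reference code it builds on: uniform class weights until epoch 160 of a 200-epoch CIFAR run, then class-balanced weights derived from the "effective number" of samples with beta = 0.9999. The function name and default values are assumptions, not read from this repository.

```python
from typing import List


def drw_class_weights(cls_num_list: List[int], epoch: int,
                      switch_epoch: int = 160, beta: float = 0.9999) -> List[float]:
    """Per-class loss weights under deferred re-weighting (hypothetical helper).

    Before `switch_epoch` every class gets weight 1; afterwards each class is weighted
    by the inverse of its effective number of samples, (1 - beta**n) / (1 - beta).
    """
    b = 0.0 if epoch < switch_epoch else beta
    # With b = 0 this is (1 - 0) / (1 - 0) = 1 for every class, i.e. no re-weighting.
    raw = [(1.0 - b) / (1.0 - b ** n) for n in cls_num_list]
    # Rescale so the weights sum to the number of classes.
    scale = len(cls_num_list) / sum(raw)
    return [w * scale for w in raw]


if __name__ == "__main__":
    counts = [5000, 500, 50]  # toy long-tailed class counts
    print(drw_class_weights(counts, epoch=10))   # uniform phase
    print(drw_class_weights(counts, epoch=170))  # re-weighted phase: rarer classes get larger weights
```

The resulting weights would typically be passed to the classification loss, for example as the `weight` argument of `torch.nn.CrossEntropyLoss`.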

Sample command to train on the ImageNet-LT dataset with LDAM+DRW+SAM.

python imnet_train_sam.py --gpu 0 --imb_type exp --imb_factor 0.01 --data_path <Path-to-Dataset> --loss_type LDAM --train_rule DRW --dataset imagenet -b 256 --epochs 90 --arch resnet50 --cos_lr --rho_schedule step --lr 0.2 --seed 0 --rho_steps 0.05 0.1 0.5 0.5 --log_results --wd 2e-4 --margin 0.3
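LDAM (label-distribution-aware margin) loss enforces a larger margin for rarer classes. The sketch below follows the formulation used in the LDAM reference implementation: the margin of class j is proportional to $n_j^{-1/4}$, rescaled so the largest margin equals `max_m`, and logits are scaled by `s = 30` before cross-entropy. Whether `--margin 0.3` in the command above sets this largest margin is an assumption, and the function names are hypothetical.

```python
import math
from typing import List


def ldam_margins(cls_num_list: List[int], max_m: float = 0.5) -> List[float]:
    # Margin proportional to n_j ** (-1/4), rescaled so the rarest class gets max_m.
    raw = [n ** -0.25 for n in cls_num_list]
    scale = max_m / max(raw)
    return [m * scale for m in raw]


def ldam_loss(logits: List[float], target: int, margins: List[float], s: float = 30.0) -> float:
    # Subtract the class-dependent margin from the true-class logit only.
    adjusted = [z - margins[target] if j == target else z for j, z in enumerate(logits)]
    scaled = [s * z for z in adjusted]
    # Cross-entropy via a numerically stable log-sum-exp.
    m = max(scaled)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in scaled))
    return log_sum_exp - scaled[target]


if __name__ == "__main__":
    margins = ldam_margins([5000, 50], max_m=0.5)  # toy head class and tail class
    print(margins)
    print(ldam_loss([2.0, 1.0], target=1, margins=margins))
```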

All commands needed to reproduce the experiments are available in run.sh.

Results and Checkpoints of Models

We show results on the CIFAR-10 LT, CIFAR-100 LT, ImageNet-LT and iNaturalist-18 datasets. Complete results are available in the paper.

Dataset                 Method          Accuracy (%)   Checkpoint
CIFAR-10 LT (IF=100)    LDAM+DRW+SAM    81.9           ckpt
                        CE+DRW+SAM      80.6           ckpt
CIFAR-100 LT (IF=100)   LDAM+DRW+SAM    45.4           ckpt
                        CE+DRW+SAM      44.6           ckpt
ImageNet-LT             LDAM+DRW+SAM    53.1           ckpt
                        CE+DRW+SAM      47.1           ckpt
iNaturalist-18          LDAM+DRW+SAM    70.1           ckpt
                        CE+DRW+SAM      65.3           ckpt

Results with GLMC

We also combine our method with GLMC (CVPR 2023), a recent state-of-the-art method, and show that it further improves performance. As conjectured in our work, we apply SAM to the re-weighting loss of GLMC to avoid saddle points. We use a $\rho$ of 0.05 for all the experiments below. The code to reproduce all the experiments is available in GLMC-2023/run.sh.

To run GLMC with SAM, pass the additional parameter --rho 0.05, as in the example command below:

python GLMC-2023/main.py --dataset cifar10 -a resnet34 --num_classes 10 --imbanlance_rate 0.02 --beta 0.5 --lr 0.01 --epochs 200 -b 64 --momentum 0.9 --weight_decay 5e-3 --resample_weighting 0.0 --label_weighting 1.2 --contrast_weight 1 --rho 0.05

Results

Method        CIFAR-10               CIFAR-100
              IF=50      IF=100      IF=50      IF=100
GLMC          89.81      87.55       62.49      57.63
GLMC + SAM    91.56      89.18       65.28      59.01
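As a quick check of the "~2% gain" quoted at the top of this README, the per-setting improvements can be computed directly from the GLMC table; they are absolute gains in percentage points and average to roughly 1.9.

```python
# Accuracies (%) copied from the GLMC results table.
settings = ["CIFAR-10 IF=50", "CIFAR-10 IF=100", "CIFAR-100 IF=50", "CIFAR-100 IF=100"]
glmc = [89.81, 87.55, 62.49, 57.63]
glmc_sam = [91.56, 89.18, 65.28, 59.01]

# Absolute gain in percentage points for each setting, rounded to the table's precision.
gains = [round(b - a, 2) for a, b in zip(glmc, glmc_sam)]
mean_gain = sum(gains) / len(gains)

for name, g in zip(settings, gains):
    print(f"{name}: +{g:.2f}")
print(f"mean gain: +{mean_gain:.2f}")
```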

Class-Wise Hessian Analysis

We also release the code to compute the spectral density and analyse the loss landscape of the trained models.

Sample command below:

python hessian_analysis.py --gpu 0 --seed 1 --exp_str sample --resume <checkpoint_path> --dataloader_hess train --log_results

Running this command computes the eigenvalue spectral density of the Hessian of the per-class loss and plots the class-wise spectral density, along with the maximum eigenvalue and the trace of the Hessian.

Overview of the arguments

In general, the Python scripts in the project accept the following flags:

  • -a: Architecture of the backbone. (resnet32|resnet50)
  • --dataset: Dataset (cifar10|cifar100)
  • --imb_type: Imbalance Type (exp|step).
  • --imb_factor: Imbalance Factor (ratio of the number of samples in the minority class to that in the majority class). Default: 0.01.
  • --epochs: Number of epochs to train for. Default: 200.
  • --loss_type: Loss Type (CE|LDAM|VS)
  • --gpu: GPU id to use.
  • --rho: $\rho$ value in SAM (Applicable to SAM runs).
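To make --imb_type and --imb_factor concrete, the sketch below shows how per-class sample counts are commonly derived for CIFAR-LT splits, assuming this repository follows the LDAM-DRW convention of the codebase it builds on: an exponential profile that decays geometrically from `img_max` to `img_max * imb_factor`, or a step profile. The function name is hypothetical.

```python
from typing import List


def img_num_per_cls(img_max: int, num_classes: int, imb_type: str = "exp",
                    imb_factor: float = 0.01) -> List[int]:
    """Per-class training-set sizes for a long-tailed split (hypothetical helper)."""
    if imb_type == "exp":
        # Class i keeps img_max * imb_factor ** (i / (C - 1)) samples: a geometric decay
        # from img_max (class 0) down to img_max * imb_factor (last class).
        return [int(img_max * imb_factor ** (i / (num_classes - 1.0))) for i in range(num_classes)]
    if imb_type == "step":
        # First half of the classes keep img_max samples, the rest keep img_max * imb_factor.
        half = num_classes // 2
        return [img_max] * half + [int(img_max * imb_factor)] * (num_classes - half)
    raise ValueError(f"unknown imb_type: {imb_type}")


if __name__ == "__main__":
    print(img_num_per_cls(5000, 10, "exp", 0.01))  # CIFAR-10 has 5,000 training images per class
```

For CIFAR-10 with --imb_factor 0.01 this gives 5,000 images for the most frequent class and 50 for the rarest, i.e. the imbalance factor of IF=100 reported in the results tables (IF is the majority-to-minority ratio, the inverse of --imb_factor).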

Acknowledgement

Our implementation is based on the LDAM and VS-Loss codebases. We use the PyTorch implementation of SAM from https://github.com/davda54/sam, and we refer to PyHessian for computing the eigenvalue spectral density and for the loss landscape analysis. We thank the authors for publicly releasing their source code.

The implementation of GLMC+SAM is based on the GLMC codebase. We thank the authors for publicly releasing the code.

Citation

If you find our paper or codebase useful, please consider citing us as:

@inproceedings{
rangwani2022escaping,
title={Escaping Saddle Points for Effective Generalization on Class-Imbalanced Data},
author={Harsh Rangwani and Sumukh K Aithal and Mayank Mishra and Venkatesh Babu Radhakrishnan},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=9DYKrsFSU2}
}
