
Loss functions for imbalanced classification and/or where Cohen's kappa is the metric

This repository contains the two loss functions that were created during the development of

Adam M. Jones, Laurent Itti, Bhavin R. Sheth, "Expert-level sleep staging using an electrocardiography-only feed-forward neural network," Computers in Biology and Medicine, 2024, doi: 10.1016/j.compbiomed.2024.108545

(Link to paper: https://doi.org/10.1016/j.compbiomed.2024.108545)

If you find this repository helpful, please cite our work.

The main repository for the paper is here: https://github.com/adammj/ecg-sleep-staging


Motivation

These loss functions were designed for imbalanced classification problems where it is not possible to oversample the minority classes or undersample the majority classes (please see the paper for a more thorough explanation of this situation). Furthermore, most classification problems assume that accuracy is the desired metric, and therefore use cross-entropy as the loss function. However, for our use case, Cohen's kappa is the correct metric (and it is only loosely correlated with accuracy).

Normally, especially with highly imbalanced data, correctly classifying the majority class(es) will almost always come at the expense of the minority class(es). Our loss functions instead use the geometric mean of the individual class performances (kappa, TPR, or PPV), all of which lie in the range [0, 1] (see below for how kappa is rescaled). This causes the loss function to balance competing ratios, instead of competing counts (which will always disfavor the minority).

Both loss functions assume that the final operation of the network is a softmax, which transforms the output into a probability for each of the N classes.

  1. GeomeanKappa: Geometric Mean of Kappas (used in paper).

    This calculates the geometric mean of each of the class-wise kappas. The class-wise kappas are scaled using (1 + k)/2, so that the default kappa range is transformed from [-1, 1] to [0, 1].

    By doing so, it will tend to improve all of the class-wise kappas.

  2. GeomeanTPRPPV: Geometric Mean of TPR and PPV.

    This calculates the geometric mean of the True Positive Rates (TPR or sensitivity) and Positive Predictive Values (PPV or precision) for each of the classes.

    By doing so, it will tend to simultaneously increase the TPR (the number of correctly classified instances of a class, divided by the total instances of that class according to reviewer 1, i.e., the rows of the confusion matrix) and the PPV (the number of correctly classified instances, divided by all instances that reviewer 2, i.e., the columns, assigned to that class). For example, with imbalanced classes, it will work to correctly classify as many of the majority class as possible (minimizing off-diagonal counts in the rows), while also minimizing the number of incorrect classifications assigned to the minority class (minimizing off-diagonal counts in the columns). A code sketch of both losses appears after this list.
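
To make the two definitions concrete, below is a minimal PyTorch sketch of both losses. This is not the repository's implementation: the class names, the constructor arguments (n_classes, eps), and the soft-confusion-matrix construction are assumptions made for illustration, so consult the repository source for the actual code. Both sketches take softmax probabilities and integer labels, build a differentiable confusion matrix (rows are reviewer 1, the true labels; columns are reviewer 2, the predicted probability mass), and return 1 minus the relevant geometric mean, so that minimizing the loss maximizes that mean.

    import torch
    import torch.nn.functional as F
    from torch import nn


    def _soft_confusion(probs: torch.Tensor, targets: torch.Tensor,
                        n_classes: int) -> torch.Tensor:
        """Differentiable confusion matrix: rows are true classes,
        columns hold the predicted probability mass."""
        one_hot = F.one_hot(targets, n_classes).to(probs.dtype)
        return one_hot.t() @ probs  # shape: (n_classes, n_classes)


    class GeomeanKappaSketch(nn.Module):
        """Sketch of GeomeanKappa: 1 - geometric mean of scaled class-wise kappas."""

        def __init__(self, n_classes: int, eps: float = 1e-10):
            super().__init__()
            self.n_classes = n_classes
            self.eps = eps  # guards against division by zero and log(0)

        def forward(self, probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            c = _soft_confusion(probs, targets, self.n_classes)
            total = c.sum()
            tp = c.diagonal()
            fn = c.sum(dim=1) - tp  # row margin minus the diagonal
            fp = c.sum(dim=0) - tp  # column margin minus the diagonal
            tn = total - tp - fn - fp
            # Standard binary (one-vs-rest) Cohen's kappa identity per class.
            denom = (tp + fp) * (fp + tn) + (tp + fn) * (fn + tn)
            kappas = 2.0 * (tp * tn - fn * fp) / (denom + self.eps)
            scaled = (1.0 + kappas) / 2.0  # map [-1, 1] -> [0, 1]
            geomean = torch.exp(torch.log(scaled + self.eps).mean())
            return 1.0 - geomean


    class GeomeanTPRPPVSketch(nn.Module):
        """Sketch of GeomeanTPRPPV: 1 - geometric mean of class-wise TPRs and PPVs."""

        def __init__(self, n_classes: int, eps: float = 1e-10):
            super().__init__()
            self.n_classes = n_classes
            self.eps = eps

        def forward(self, probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            c = _soft_confusion(probs, targets, self.n_classes)
            tp = c.diagonal()
            tpr = tp / (c.sum(dim=1) + self.eps)  # recall per class (rows)
            ppv = tp / (c.sum(dim=0) + self.eps)  # precision per class (columns)
            rates = torch.cat([tpr, ppv])
            geomean = torch.exp(torch.log(rates + self.eps).mean())
            return 1.0 - geomean

Because every factor lies in [0, 1], a single collapsed class drags the geometric mean toward zero, which is what keeps the optimizer from sacrificing the minority class to the majority.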


Comparisons against other loss functions

To compare against our loss function, we retrained our final model with several different loss functions substituted in. We'd like to highlight that for the two functions where the Overall kappa is slightly higher (+1%), their minority class (N1) performance is significantly worse (20-27% lower).

The table gives the kappa for each sleep stage and loss function pair.

Loss function                      Overall   Wake    N1      N2      N3      REM
Geometric Mean of Kappas (ours)    0.726     0.862   0.373   0.671   0.703   0.805
Cross-entropy                      0.734     0.867   0.274   0.682   0.699   0.805
Cross-entropy (weighted)           0.669     0.845   0.332   0.583   0.677   0.786
Focal loss                         0.732     0.862   0.297   0.679   0.703   0.801
Cohen’s kappa (overall)            0.720     0.854   0.000   0.669   0.697   0.795
Ratio of ours to best              99%       99%     100%    98%     100%    100%

Additional details

GeomeanTPRPPV was used for a significant fraction of the hyperparameter search and performed quite well. However, once I figured out how to calculate the class-wise kappas using a simple equation, I switched to GeomeanKappa. This is because, mathematically, it should be a little closer to the desired metric, Cohen's kappa (which is a weighted average of the class-wise kappas).
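
For reference, one simple closed form for the one-vs-rest kappa of a single class, written in terms of that class's binary confusion counts, is the standard binary Cohen's kappa identity (whether this matches the exact equation used in the repository is an assumption):

    kappa_i = 2*(TP*TN - FN*FP) / ((TP + FP)*(FP + TN) + (TP + FN)*(FN + TN))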

calculate_loss is a separate function, and the loss confusion matrix is stored in loss_confusion, to aid some calculations that are done elsewhere in my training code. However, each loss function is a drop-in replacement for any other PyTorch loss function.
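
As a hedged illustration of that drop-in usage, using the hypothetical sketch classes from above (the actual class names are GeomeanKappa and GeomeanTPRPPV, and their constructor signatures may differ):

    # Hypothetical 5-stage sleep example (Wake, N1, N2, N3, REM).
    criterion = GeomeanKappaSketch(n_classes=5)

    logits = model(batch)                  # `model` and `batch` are placeholders
    probs = torch.softmax(logits, dim=1)   # both losses expect softmax output
    loss = criterion(probs, labels)        # `labels`: (batch_size,) integer stages
    loss.backward()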


MIT License

Copyright (C) 2024 Adam M. Jones
