
Memory-Efficient Cross Entropy Loss

TL;DR

This repo contains an implementation of a linear projection + cross-entropy loss PyTorch module that has substantially lower memory consumption compared to a standard implementation, with almost no additional compute cost. The memory savings come from two optimizations: 1) overwriting the logits with their gradients in-place and 2) not materializing the entire logits tensor.

Overview

In networks trained to perform classification tasks, such as language models, the final layer is generally a linear projection from dim channels to n_classes channels to compute the logits, followed by cross-entropy loss. When n_classes is large relative to dim, the logits consume a large amount of GPU memory compared to the other activations in the network. For example, Mistral 7B has a vocabulary size of 32,000 but a much smaller hidden dimension of 4,096, so its logits take up roughly 8x as much GPU memory as the preceding activations.
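To make that 8x figure concrete, here is a quick back-of-envelope calculation (the 8,192-token batch size is a hypothetical choice for illustration; the other shapes are Mistral 7B's):

```python
# Rough memory estimate for the Mistral 7B example above (float32 tensors).
n_tokens, dim, n_classes = 8192, 4096, 32_000  # batch size is an assumption
bytes_per_float32 = 4

hidden_bytes = n_tokens * dim * bytes_per_float32        # preceding activations
logits_bytes = n_tokens * n_classes * bytes_per_float32  # logits

print(f"hidden activations: {hidden_bytes / 2**20:6.0f} MiB")  # ~128 MiB
print(f"logits:             {logits_bytes / 2**20:6.0f} MiB")  # ~1000 MiB, ~8x more
```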

This repo contains two optimizations to reduce the memory usage of a linear projection followed by cross-entropy loss, implemented in PyTorch + Triton. These optimizations primarily focus on reducing the memory usage of the logits tensor and its gradient since these tensors can dominate overall memory usage:

  • Optimization 1: Overwrite the logits tensor with its gradient in-place to avoid allocating more memory for the gradients
  • Optimization 2: Compute the loss and gradients in a loop of $K$ micro-batches in the forward pass so that we only materialize $\frac{1}{K}$ of the full logits tensor

These optimizations can reduce peak memory usage of a linear projection + cross-entropy loss by several times with almost no additional compute cost.
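As a point of reference, the peak memory of the unfused baseline can be probed directly in PyTorch (shapes assumed for illustration; this is the implementation the optimizations below improve on):

```python
import torch

# Peak-memory probe for the unfused linear projection + cross-entropy baseline.
N, D, V = 8192, 4096, 32_000  # assumed shapes for illustration
x = torch.randn(N, D, device="cuda", requires_grad=True)
w = torch.randn(V, D, device="cuda", requires_grad=True)
targets = torch.randint(V, (N,), device="cuda")

torch.cuda.reset_peak_memory_stats()
loss = torch.nn.functional.cross_entropy(x @ w.t(), targets)
loss.backward()
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```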

Performance Analysis

Figure 1 plots the peak memory usage (top row) and median wall clock time (bottom row) before and after applying these optimizations on an A100 GPU.

Figure 1 (generated by running python ./benchmark.py)

In the top row, the measured memory is compared to $4(2ND + 2DV + nNV)$, where 4 is the number of bytes in a float32, $N$ = n_tokens, $D$ = dim, and $V$ = n_classes. The three terms account for the input tensor, the linear projection weights, and the logits; the factors of two account for also storing the gradients. Four reference lines are plotted for $n = 0, 1, 2, 3$. The memory-efficient Triton implementation doesn't materialize the full logits tensor and thus corresponds to $n = 0$. The PyTorch memory usage appears to vary between $n = 3$ for smaller $D$ and $n = 1$ for larger $D$; perhaps for larger $D$, PyTorch more aggressively frees the memory used by the logits and logit gradients (and possibly also the logit exponentials) during the backward pass.

PyTorch thus uses roughly $\frac{2ND + 2DV + nNV}{2ND + 2DV} = 1 + \frac{n}{2} \left( \frac{D}{V} + \frac{D}{N} \right)^{-1}$ times more memory than the memory-efficient Triton implementation, with $1 \leq n \leq 3$. This is an asymptotic improvement when $D \ll V$ and $D \ll N$.
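For instance, plugging in the Mistral 7B shapes from the overview with an assumed batch of $N = 8192$ tokens and taking $n = 2$:

$$1 + \frac{2}{2} \left( \frac{4096}{32000} + \frac{4096}{8192} \right)^{-1} = 1 + (0.128 + 0.5)^{-1} \approx 2.6$$

i.e., roughly 2.6x lower peak memory in this regime.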

Optimization 1: Overwrite logits with their gradients

During the backward pass of a linear projection + cross-entropy loss module, we no longer need to keep the logits in memory after computing their gradients. So, we overwrite the logits in-place with their gradients in the backward pass to avoid allocating any new memory for the gradients.
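A minimal pure-PyTorch sketch of this idea (a simplified stand-in for illustration; the actual repo fuses these steps into Triton kernels). Because the logits buffer is created inside the autograd function, it is safe to overwrite it in-place:

```python
import torch

class LinearCrossEntropy(torch.autograd.Function):
    """Sketch of Optimization 1: the buffer holding the logits is reused for
    their gradient, so no new memory is allocated for d(loss)/d(logits)."""

    @staticmethod
    def forward(ctx, x, weight, targets):
        # x: (N, D), weight: (V, D), targets: (N,)
        logits = x @ weight.t()  # (N, V)
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        loss = -logprobs.gather(1, targets[:, None]).mean()

        # dL/dlogits = (softmax(logits) - one_hot(targets)) / N,
        # written directly into the `logits` buffer
        logits.copy_(logprobs.exp())
        logits[torch.arange(len(targets), device=x.device), targets] -= 1.0
        logits /= len(targets)

        ctx.save_for_backward(x, weight, logits)  # `logits` now holds its gradient
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, grad_logits = ctx.saved_tensors
        grad_logits = grad_logits * grad_output  # chain rule through the scalar loss
        return grad_logits @ weight, grad_logits.t() @ x, None
```

Usage: loss = LinearCrossEntropy.apply(x, weight, targets).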

The memory savings from this optimization are (approximately) represented by the difference between the blue line (without this optimization) and orange line (with this optimization) in Figure 1, above.

Optimization 2: Avoid materializing full logits tensor

To avoid materializing the full logits tensor, we split the batch into $K$ micro-batches. For each micro-batch, we compute both the loss and the logit gradients (up to a scale factor) during the forward pass. With the logit gradients, we can then compute the gradients w.r.t. the input features and linear projection weights. This way, we never materialize the entire logits tensor: we materialize $\frac{1}{K}$ of the logits at a time, compute the gradients we need, then discard that slice. Note that this requires no additional recomputation, since it is all done in the forward pass.

The reason we can compute the logit gradients in the forward pass is that the output of this module is a scalar (since we assume we will do either a mean or sum reduction on the loss). Therefore, to get the correct gradients in the backward pass, we can simply multiply the gradients we computed in the forward pass by the grad_output scalar.
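The following sketch shows the micro-batched loop in plain PyTorch rather than the repo's fused Triton kernels (the helper name and structure are illustrative assumptions):

```python
import torch

def fused_forward(x, weight, targets, K):
    """Compute the mean CE loss and the grads w.r.t. x and weight in the
    forward pass, materializing only N/K rows of logits at a time."""
    N = x.shape[0]
    loss = torch.zeros((), device=x.device, dtype=torch.float32)
    grad_x = torch.empty_like(x)       # filled micro-batch by micro-batch
    grad_w = torch.zeros_like(weight)  # accumulated across micro-batches

    for x_mb, t_mb, gx_mb in zip(x.chunk(K), targets.chunk(K), grad_x.chunk(K)):
        logits = x_mb @ weight.t()  # (N/K, V): the only logits slice alive
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        loss += -logprobs.gather(1, t_mb[:, None]).sum() / N

        # Optimization 1 applied per micro-batch: overwrite logits with dL/dlogits
        logits.copy_(logprobs.exp())
        logits[torch.arange(len(t_mb), device=x.device), t_mb] -= 1.0
        logits /= N

        gx_mb.copy_(logits @ weight)  # writes into the corresponding grad_x view
        grad_w += logits.t() @ x_mb
        # `logits` is dropped here; its memory can be reused on the next iteration

    return loss, grad_x, grad_w
```

In a full autograd.Function, the backward pass would simply return grad_x and grad_w scaled by the incoming grad_output scalar, as described above.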

The top row in Figure 1 shows the memory savings from this optimization for different values of $K$ (n_loop_iters in Figure 1 refers to the number of micro-batches $K$). The bottom row of Figure 1 shows that the median wall clock time is nearly identical before and after these optimizations.

Note that we see diminishing returns in peak memory usage as we scale the hidden dim (dim). This is because peak memory usage becomes determined by the size of the linear projection's weights & gradients, rather than the logits, once the hidden dim is sufficiently large (right column in Figure 1).
