minGPT-flax

A basic transformer implementation for autoregressive sequence modeling in Flax/JAX, written for educational purposes.

Also includes some bells and whistles:

  • Data parallelism. By default, we train on all available GPUs. This can massively speed up training, even on smaller batch sizes.
  • Chunked self-attention, adapted from the approach described by Markus Rabe and Charles Staats [1]. This trades a small amount of extra runtime for avoiding the quadratic memory cost of standard self-attention; see the sketch below.

[1] https://arxiv.org/pdf/2112.05682v2.pdf
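
For intuition, here is a minimal sketch of the chunking idea in JAX. It is not the code used in this repo (the function name `chunked_attention` is made up for illustration), it omits the causal mask and dropout that the GPT blocks need, and it assumes the sequence length is divisible by both chunk sizes:

import jax
import jax.numpy as jnp


def chunked_attention(q, k, v, q_chunk_size, kv_chunk_size):
    # q, k, v: (seq_len, num_heads, head_dim). Softmax statistics are accumulated
    # chunk-by-chunk, so the full (seq_len, seq_len) attention matrix is never
    # materialized at once.
    seq_len, num_heads, head_dim = q.shape
    scale = head_dim ** -0.5

    def attend_to_kv_chunk(q_chunk, k_chunk, v_chunk):
        # Unnormalized attention for one (query chunk, key/value chunk) pair.
        s = jnp.einsum("qhd,khd->qhk", q_chunk, k_chunk) * scale
        m = jnp.max(s, axis=-1, keepdims=True)  # per-query max, for numerical stability
        p = jnp.exp(s - m)
        out = jnp.einsum("qhk,khd->qhd", p, v_chunk)
        denom = jnp.sum(p, axis=-1, keepdims=True)
        return out, denom, m

    def process_q_chunk(q_chunk):
        k_chunks = k.reshape(-1, kv_chunk_size, num_heads, head_dim)
        v_chunks = v.reshape(-1, kv_chunk_size, num_heads, head_dim)
        outs, denoms, maxes = jax.lax.map(
            lambda kv: attend_to_kv_chunk(q_chunk, kv[0], kv[1]),
            (k_chunks, v_chunks),
        )
        # Combine per-chunk results, renormalizing against the global max so the
        # output matches an exact softmax.
        global_max = jnp.max(maxes, axis=0)
        correction = jnp.exp(maxes - global_max)
        out = jnp.sum(outs * correction, axis=0)
        denom = jnp.sum(denoms * correction, axis=0)
        return out / denom

    q_chunks = q.reshape(-1, q_chunk_size, num_heads, head_dim)
    out = jax.lax.map(process_q_chunk, q_chunks)
    return out.reshape(seq_len, num_heads, head_dim)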

Usage

Install:

pip install -r requirements.txt

Train:

$ python train_char.py --help

usage: train_char.py [-h] --dataset-path PATH [--experiment-name STR] [--restore-checkpoint] [--max-epochs INT]
                     [--minibatch-size INT] [--block-size INT] [--gpt-config.vocab-size INT] [--gpt-config.block-size INT]
                     [--gpt-config.n-head INT] [--gpt-config.resid-pdrop FLOAT] [--gpt-config.attn-pdrop FLOAT]
                     [--gpt-config.chunk-attention] [--gpt-config.q-chunk-size INT] [--gpt-config.kv-chunk-size INT]
                     [--gpt-config.n-layer INT] [--gpt-config.embd-dim INT] [--gpt-config.embd-pdrop FLOAT]
                     [--optimizer-config.learning-rate FLOAT] [--optimizer-config.no-lr-decay]
                     [--optimizer-config.adam-b1 FLOAT] [--optimizer-config.adam-b2 FLOAT]
                     [--optimizer-config.warmup-tokens INT] [--optimizer-config.final-tokens INT]
                     [--optimizer-config.weight-decay FLOAT] [--optimizer-config.grad-norm-clip FLOAT]

required arguments:
  --dataset-path PATH   Path to a text file, to be loaded for training. Needs to fit in memory.

optional arguments:
  -h, --help            show this help message and exit
  --experiment-name STR
                        (default: char_2022-01-07-18:01:54)
  --restore-checkpoint
  --max-epochs INT      (default: 1000)
  --minibatch-size INT  (default: 128)
  --block-size INT      (default: 128)
  --gpt-config.vocab-size INT
                        (default: 256)
  --gpt-config.block-size INT
                        The history/context length of our sequence model. (default: 128)
  --gpt-config.n-head INT
                        Output size for multi-headed self-attention. (default: 8)
  --gpt-config.resid-pdrop FLOAT
                        Dropout probability. (default: 0.1)
  --gpt-config.attn-pdrop FLOAT
                        Dropout probability. (default: 0.1)
  --gpt-config.chunk-attention
                        Enable attention chunking to trade runtime for memory efficiency. We implement an
                        approach similar to the algorithm presented here:
                        https://arxiv.org/pdf/2112.05682v2.pdf

                        If chunking is enabled, both q_chunk_size and kv_chunk_size must be set.
                        Note that `block_size % chunk_size` must be 0 for both chunk sizes.
  --gpt-config.q-chunk-size INT
                        (default: None)
  --gpt-config.kv-chunk-size INT
                        (default: None)
  --gpt-config.n-layer INT
                        (default: 8)
  --gpt-config.embd-dim INT
                        (default: 512)
  --gpt-config.embd-pdrop FLOAT
                        Dropout probability. (default: 0.1)
  --optimizer-config.learning-rate FLOAT
                        (default: 0.0006)
  --optimizer-config.no-lr-decay
                        If decay is enabled, we use cosine annealing.
  --optimizer-config.adam-b1 FLOAT
                        (default: 0.9)
  --optimizer-config.adam-b2 FLOAT
                        (default: 0.95)
  --optimizer-config.warmup-tokens INT
                        Tokens before reaching full learning rate. (default: 10240)
  --optimizer-config.final-tokens INT
                        At what point we reach 10% of the original learning rate. (default: 2560)
  --optimizer-config.weight-decay FLOAT
                        L2 regularization coefficient. (default: 0.1)
  --optimizer-config.grad-norm-clip FLOAT
                        (default: 1.0)

As an example, to train with self-attention chunk sizes of 64:

$ python train_char.py --dataset-path ./some_text_file --gpt-config.chunk-attention --gpt-config.q-chunk-size 64 --gpt-config.kv-chunk-size 64

The training script will attempt to use all available GPUs; CUDA_VISIBLE_DEVICES may be helpful if this is undesired.
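
For example, to restrict training to a single GPU:

$ CUDA_VISIBLE_DEVICES=0 python train_char.py --dataset-path ./some_text_file

Under the hood, single-host data parallelism in JAX usually amounts to a `jax.pmap` over a device axis. The toy sketch below is not the repo's actual train step (names like `train_step` are hypothetical, and a stand-in quadratic loss replaces the GPT cross-entropy), but it illustrates the pattern: replicate the parameters, shard each batch along a leading device axis, and average gradients across devices with `jax.lax.pmean`.

import functools

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Toy quadratic loss standing in for the model's cross-entropy loss.
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


@functools.partial(jax.pmap, axis_name="device")
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients (and the loss) across devices so every replica applies
    # the same update.
    grads = jax.lax.pmean(grads, axis_name="device")
    loss = jax.lax.pmean(loss, axis_name="device")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
    return new_params, loss


if __name__ == "__main__":
    num_devices = jax.local_device_count()
    # Replicate parameters across devices; shard the batch along a new leading
    # device axis of size num_devices.
    params = jax.device_put_replicated({"w": jnp.ones((4, 1))}, jax.local_devices())
    batch = {
        "x": jnp.ones((num_devices, 8, 4)),
        "y": jnp.zeros((num_devices, 8, 1)),
    }
    params, loss = train_step(params, batch)
    print(loss)  # one (identical) loss value per device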

Eval (sampling):

$ python eval_char.py

usage: eval_char.py [-h] --experiment-name STR [--sample-steps INT] [--sample-from-top-k INT]

required arguments:
  --experiment-name STR

optional arguments:
  -h, --help            show this help message and exit
  --sample-steps INT    (default: 500)
  --sample-from-top-k INT
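
The --sample-from-top-k flag corresponds to standard top-k sampling: at each step, keep only the k most likely next tokens and sample among them. A hedged sketch of that idea in JAX (not necessarily how eval_char.py implements it; `sample_top_k` is a made-up name):

import jax
import jax.numpy as jnp


def sample_top_k(rng, logits, k):
    # logits: (vocab_size,) next-token logits. Keep the k largest logits and
    # sample a token id from the renormalized distribution over them.
    top_logits, top_indices = jax.lax.top_k(logits, k)
    choice = jax.random.categorical(rng, top_logits)
    return top_indices[choice]


rng = jax.random.PRNGKey(0)
logits = jnp.array([2.0, 0.5, -1.0, 3.0])
token = sample_top_k(rng, logits, k=2)  # either index 3 or index 0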

Links

Third-party:

This repo also serves as a testbed for a few "core infrastructure" libraries that I've been working on, including:

  • fifteen, which contains utilities for training: data loading, experiment management, etc.
  • dcargs, which is used for unifying experiment configuration with type-safe argument parsing.
  • jax_dataclasses, which is used to construct type-safe PyTree structures.
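
As a rough illustration of the dcargs pattern (a simplified, hypothetical config rather than this repo's actual classes, and with the caveat that the exact entry-point name has varied across dcargs versions), a nested dataclass maps directly onto typed CLI flags like the --gpt-config.* options shown above:

import dataclasses

import dcargs


@dataclasses.dataclass
class GPTConfig:
    vocab_size: int = 256
    block_size: int = 128
    n_head: int = 8


@dataclasses.dataclass
class TrainConfig:
    dataset_path: str  # required: no default
    minibatch_size: int = 128
    gpt_config: GPTConfig = dataclasses.field(default_factory=GPTConfig)


if __name__ == "__main__":
    # Builds an argument parser from the dataclass annotations, producing nested
    # flags such as --gpt-config.n-head. (The entry-point name is an assumption;
    # some dcargs versions expose this as dcargs.parse rather than dcargs.cli.)
    config = dcargs.cli(TrainConfig)
    print(config)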

To-do list

  • Model
  • Training boilerplate
    • Minimal training loop
    • Weight decay masking
    • Learning rate scheduling
    • Tensorboard logging
    • Checkpointing
    • Multi-GPU support
  • Demos
    • Character-level language model
      • Training script
        python train_char.py --help
      • Sampling/rollout demo
        python eval_char.py --help
    • BPE-based language model
  • Tangentially related reach goals
    • Vision transformer
    • Transformer w/ Perceiver IO-style latent vectors?
    • Weight loading from OpenAI's released model?
