Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators such as GPUs and TPUs.
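As a minimal sketch of what that combination looks like in practice, the snippet below writes a loss with `jax.numpy`, differentiates it with `jax.grad`, and compiles the result with `jax.jit`. The function name `loss` and the toy data are illustrative assumptions, not taken from any library listed here.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a linear model, written with the NumPy-like API.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates with respect to the first argument (w);
# jax.jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# Toy data (hypothetical, for illustration only).
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # gradient of the loss at w, same shape as w
```

The same code runs unchanged on CPU, GPU, or TPU, depending on which backend the installed JAX build targets.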

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!

Contents

  • Neural Network Libraries
    • Flax - Centered on flexibility and clarity.
    • Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    • Objax - Has an object oriented design similar to PyTorch.
    • Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
    • Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
    • Jraph - Lightweight graph neural network library.
    • Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
    • HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    • Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
    • Scenic - A JAX library for computer vision research and beyond.
  • NumPyro - Probabilistic programming based on the Pyro library.
  • Chex - Utilities to write and test reliable JAX code.
  • Optax - Gradient processing and optimization library.
  • RLax - Library for implementing reinforcement learning agents.
  • JAX, M.D. - Accelerated, differentiable molecular dynamics.
  • Coax - Turn RL papers into code, the easy way.
  • Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
  • cvxpylayers - Construct differentiable convex optimization layers.
  • TensorLy - Tensor learning made simple.
  • NetKet - Machine Learning toolbox for Quantum Physics.
  • Fortuna - AWS library for Uncertainty Quantification in Deep Learning.
  • BlackJAX - Library of samplers for JAX.

This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.

  • Neural Network Libraries
    • FedJAX - Federated learning in JAX, built on Optax and Haiku.
    • Equivariant MLP - Construct equivariant neural network layers.
    • jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
    • Parallax - Immutable Torch Modules for JAX.
  • jax-unirep - Library implementing the UniRep model for protein machine learning applications.
  • jax-flows - Normalizing flows in JAX.
  • sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
  • jax-cosmo - Differentiable cosmology library.
  • efax - Exponential Families in JAX.
  • mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
  • imax - Image augmentations and transformations.
  • FlaxVision - Flax version of TorchVision.
  • Oryx - Probabilistic programming language based on program transformations.
  • Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
  • delta PV - A photovoltaic simulator with automatic differentiation.
  • jaxlie - Lie theory library for rigid body transformations and optimization.
  • BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
  • flaxmodels - Pretrained models for JAX/Flax.
  • CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing.
  • exojax - Auto-differentiable spectrum modeling of exoplanets and brown dwarfs, compatible with JAX.
  • JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
  • PIX - Image processing library in JAX, for JAX.
  • bayex - Bayesian Optimization powered by JAX.
  • JaxDF - Framework for differentiable simulators with arbitrary discretizations.
  • tree-math - Convert functions that operate on arrays into functions that operate on PyTrees.
  • jax-models - Implementations of research papers originally without code or code written with frameworks other than JAX.
  • PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
  • EvoJAX - Hardware-accelerated neuroevolution.
  • evosax - JAX-based evolution strategies.
  • SymJAX - Symbolic CPU/GPU/TPU programming.
  • mcx - Express & compile probabilistic programs for performant inference.
  • Einshape - DSL-based reshaping library for JAX and other frameworks.
  • ALX - Open-source library for distributed matrix factorization using Alternating Least Squares; for details, see the paper "ALX: Large Scale Matrix Factorization on TPUs".
  • Diffrax - Numerical differential equation solvers in JAX.
  • tinygp - The tiniest of Gaussian process libraries in JAX.
  • gymnax - Reinforcement Learning Environments with the well-known gym API.
  • Mctx - Monte Carlo tree search algorithms in native JAX.
  • KFAC-JAX - Second Order Optimization with Approximate Curvature for NNs.
  • TF2JAX - Convert functions/graphs to JAX functions.
  • jwave - A library for differentiable acoustic simulations.
  • GPJax - Gaussian processes in JAX.
  • Jumanji - A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.
  • Eqxvision - Equinox version of Torchvision.
  • JAXFit - Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
  • econpizza - Solve macroeconomic models with heterogeneous agents using JAX.
  • SPU - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).
  • jax-tqdm - Add a tqdm progress bar to JAX scans and loops.
  • safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors.
  • Kernex - Differentiable stencil decorators in JAX.
  • MaxText - A simple, performant, and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.
  • Pax - A JAX-based machine learning framework for training large-scale models.
  • Praxis - The layer library for Pax with a goal to be usable by other JAX-based ML projects.
  • purejaxrl - Vectorisable, end-to-end RL algorithms in JAX.
  • Lorax - Automatically apply LoRA to JAX models (Flax, Haiku, etc.).
  • SCICO - Scientific computational imaging in JAX.
  • Spyx - Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.
  • BrainPy - Brain Dynamics Programming in Python.
  • OTT-JAX - Optimal transport tools in JAX.
  • QDax - Quality Diversity optimization in JAX.
  • JAX Toolbox - Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
  • Pgx - Vectorized board game environments for RL with an AlphaZero example.
  • XLB - A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
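Several entries above (Equinox, tree-math) refer to PyTrees, JAX's term for nested containers of arrays such as dicts, lists, and tuples. The following is a minimal sketch using only core `jax.tree_util` functions; the `params` structure and the `halve` helper are hypothetical and do not come from any of the listed libraries.

```python
import jax
import jax.numpy as jnp

# A PyTree: a nested dict whose leaves are arrays (hypothetical parameters).
params = {
    "dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scale": jnp.array(2.0),
}


def halve(leaf):
    # An ordinary array function; tree_map applies it to every leaf.
    return leaf * 0.5


# The result has the same nested structure as params.
halved = jax.tree_util.tree_map(halve, params)

# Flatten the tree into a list of its leaf arrays.
leaves = jax.tree_util.tree_leaves(params)
```

Libraries built on PyTrees rely on this structure-preserving behavior so that transformations like `jax.grad` and `jax.jit` can accept whole models or parameter sets as arguments.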

JAX

Flax

Haiku

Trax

  • Reformer - Implementation of the Reformer (efficient transformer) architecture.

NumPyro

This section contains papers focused on JAX (e.g., JAX-based library whitepapers and research on JAX). Papers implemented in JAX are listed in the Models/Projects section.

  • Jax in Action - A hands-on guide to using JAX for deep learning and other mathematically-intensive applications.

Contributing

Contributions welcome! Read the contribution guidelines first.
