Welcome to this repository of deep learning models written in JAX and Flax! JAX is a numerical computing library that runs on CPUs as well as hardware accelerators such as GPUs and TPUs. Flax is a neural network library built on top of JAX that provides a flexible way to build and train machine learning models.
This repository contains a collection of deep learning models, including multilayer perceptrons, convolutional neural networks, and autoencoders. Each model's training pipeline is demonstrated in a Google Colab notebook.
This repository is inspired by Sebastian Raschka's Deep Learning Model Zoo, which is written in PyTorch and TensorFlow.
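To give a flavor of what the notebooks contain, below is a minimal sketch of a Flax model together with a plain-SGD training step. It is illustrative only: the module, layer sizes, and learning rate are hypothetical and do not come from any particular notebook, and the notebooks may use an optimizer library rather than the hand-written update shown here.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    """A small multilayer perceptron for flattened 28x28 MNIST images."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)  # one logit per class
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))

def loss_fn(params, images, labels):
    logits = model.apply(params, images)
    one_hot = jax.nn.one_hot(labels, 10)
    return -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))

@jax.jit
def train_step(params, images, labels, lr=1e-3):
    # Differentiate the loss w.r.t. the parameter pytree and apply plain SGD.
    grads = jax.grad(loss_fn)(params, images, labels)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# One update on a dummy batch.
new_params = train_step(params, jnp.zeros((32, 784)), jnp.zeros((32,), dtype=jnp.int32))
```

Because `jax.grad` works on the whole parameter pytree, the same `train_step` structure carries over unchanged to the larger models listed below.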
## Multilayer Perceptrons

Title | Dataset | Notebooks |
---|---|---|
Basic MLP | MNIST | |
## Convolutional Neural Networks

Title | Dataset | Notebooks |
---|---|---|
Basic ConvNet | MNIST | |
Basic ConvNet | CIFAR-10 | |
Basic ConvNet with dropout | CIFAR-10 | |
Basic ConvNet with batchnorm | CIFAR-10 | |
ResNet | CIFAR-10 | |
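Dropout and batch normalization are the main additions in the ConvNet variants above. A minimal Flax linen sketch of a CNN using both (layer sizes are illustrative and not taken from the notebooks) could look like the following; note that dropout needs its own PRNG stream and batch norm keeps mutable running statistics:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class ConvNet(nn.Module):
    """A small CNN with batch normalization and dropout (illustrative sizes)."""

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        x = nn.Dense(features=10)(x)
        return x

model = ConvNet()
x = jnp.ones((1, 32, 32, 3))  # a dummy CIFAR-10-shaped batch
variables = model.init(
    {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}, x
)
# Training-mode forward pass: supply a dropout RNG and let batch_stats update.
logits, mutated = model.apply(
    variables, x, train=True,
    rngs={"dropout": jax.random.PRNGKey(2)},
    mutable=["batch_stats"],
)
```

During evaluation, calling the model with `train=False` disables dropout and makes batch norm use the accumulated `batch_stats` instead of per-batch statistics.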
## Autoencoders

Title | Dataset | Notebooks |
---|---|---|
MLP autoencoder | MNIST | |
Conv autoencoder | MNIST | |
Variational MLP autoencoder | MNIST | |
Variational Conv autoencoder | MNIST | |
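The variational autoencoders differ from the plain ones mainly in how the latent code is produced: the encoder outputs a mean and log-variance, a latent is sampled with the reparameterization trick, and a KL term is added to the reconstruction loss. A hypothetical sketch of that piece (sizes and names are illustrative, not taken from the notebooks):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class VariationalEncoder(nn.Module):
    """Encoder head of a VAE: maps inputs to a Gaussian over latents."""
    latent_dim: int = 16

    @nn.compact
    def __call__(self, x, rng):
        h = nn.relu(nn.Dense(256)(x))
        mean = nn.Dense(self.latent_dim)(h)
        logvar = nn.Dense(self.latent_dim)(h)
        # Reparameterization trick: z = mean + sigma * eps with eps ~ N(0, I),
        # so the sampling step stays differentiable w.r.t. mean and logvar.
        eps = jax.random.normal(rng, mean.shape)
        z = mean + jnp.exp(0.5 * logvar) * eps
        return z, mean, logvar

def kl_divergence(mean, logvar):
    # KL(N(mean, sigma^2) || N(0, I)), summed over latent dimensions.
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)

encoder = VariationalEncoder()
x = jnp.ones((8, 784))
params = encoder.init(jax.random.PRNGKey(0), x, jax.random.PRNGKey(1))
z, mean, logvar = encoder.apply(params, x, jax.random.PRNGKey(2))
kl = jnp.mean(kl_divergence(mean, logvar))
```

The training objective is then the decoder's reconstruction term plus this KL term, i.e. the negative evidence lower bound (ELBO).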
This repository includes code that has been adapted from various sources, including the Flax examples, the UvA DL tutorials, and the JAXopt examples.
All notebooks in this repository are written for didactic purposes and are not intended to serve as performance benchmarks.