This repository contains the code for my MEng thesis, *To attract or to oscillate: Validating dynamics with behavior*. All code is written in Python, and the recurrent neural network (RNN) models are implemented and trained with JAX, Flax, and Optax.
RNNs can learn two distinct dynamical systems to compute modular arithmetic.
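As an illustration of what a modular-arithmetic sequence task for an RNN might look like, here is a minimal sketch: inputs are random integers and targets are the running sum modulo `m`. The function name and task details are hypothetical, not taken from the thesis code.

```python
import numpy as np

def make_modular_sequences(m=5, T=10, n_seq=4, seed=0):
    """Hypothetical task sketch (not the thesis's exact setup):
    inputs are random integers in [0, m); targets are the
    running sum of the inputs modulo m."""
    rng = np.random.default_rng(seed)
    xs = rng.integers(0, m, size=(n_seq, T))  # input token sequences
    ys = np.cumsum(xs, axis=1) % m            # per-step modular targets
    return xs, ys

xs, ys = make_modular_sequences()
print(xs.shape, ys.shape)  # (4, 10) (4, 10)
```

A task of this shape forces the network to maintain a cyclic internal state, which is what makes the attractor-versus-oscillation question interesting.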
The `results/` and `scripts/` folders each contain two subfolders:

- `development/` — Jupyter notebooks used during development of the `src/` Python package
- `experiments/` — Jupyter notebooks used to train models and generate the results discussed in the thesis
JAX is an incredibly powerful Python package. For training RNNs, JAX's `jax.lax.scan` is significantly faster than unrolling the recurrence with a Python `for` loop. Using it, I was able to train 16,128 RNNs on the MIT SuperCloud HPC in about 60 hours. Given that training and analyzing RNNs has become routine in computational neuroscience, I'm betting that the field will shift to JAX, Flax, and Optax.