
Models of Mental Simulation

This repo contains pretrained PyTorch models that are optimized to predict the future state of their environment.

As these models are stimulus-computable, they can be applied to any new video for further neural and behavioral comparisons of your design.

This repository is based on our paper:

Aran Nayebi, Rishi Rajalingham, Mehrdad Jazayeri, Guangyu Robert Yang

"Neural foundations of mental simulation: future prediction of latent representations on dynamic scenes"

37th Conference on Neural Information Processing Systems (NeurIPS 2023). Selected for spotlight.

Here's a video recording that explains our work a bit.

Getting started

It is recommended that you install this repo within a virtual environment (Python 3.9 recommended) and run inference there. An example command for doing this with Anaconda would be:

conda create -y -n your_env python=3.9.7 anaconda

To install this package and all of its dependencies, clone this repo on your machine and then install it via pip:

  1. Run git clone https://github.com/anayebi/mental-sim.git to clone the repository.
  2. cd mental-sim/
  3. conda activate your_env
  4. Run pip install -e . to install the current version.

The installation will automatically download the necessary dependencies, including ptutils and brainmodel_utils, which are my Python packages for training PyTorch models and extracting their features for neural and behavioral regression, respectively.

Available Pretrained Models

To get the saved checkpoints of the models, simply run this bash script:

./get_checkpoints.sh

This will automatically download them from Hugging Face into the folder ./trained_models/ in the current directory. If you want only a subset of the models, feel free to modify the for loop in the above bash script.
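Alternatively, single checkpoints can be fetched directly with the huggingface_hub Python package. This is a minimal sketch rather than the repo's official workflow; the repo_id and filename below are hypothetical placeholders, and the actual values can be read off the for loop in get_checkpoints.sh:

# Minimal sketch: download one checkpoint from Hugging Face.
# NOTE: repo_id and filename are hypothetical placeholders; the real
# values are listed in get_checkpoints.sh.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="anayebi/mental-sim-models",  # hypothetical repo id
    filename="vc1_ctrnn_physion_224.pt",  # hypothetical checkpoint name
    local_dir="./trained_models/",
)
print(ckpt_path)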

Models are named according to the convention of [architecture]_[pretraining-dataset]_[image_size], all of which are described in our paper. You can see this notebook for loading all of our pretrained models.
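As a rough illustration of what loading looks like (the checkpoint filename and state-dict layout here are assumptions; the notebook above shows the exact code for each model):

# Minimal sketch of loading a downloaded checkpoint.
# The filename and the "state_dict" key are assumptions; see the
# linked notebook for the actual loading code per model.
import torch

ckpt = torch.load(
    "./trained_models/vc1_ctrnn_physion_224.pt",  # hypothetical filename
    map_location="cpu",
)
state_dict = ckpt.get("state_dict", ckpt)  # unwrap if wrapped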

Some models may be better suited than others based on your needs, but we generally recommend:

  • VC-1+CTRNN/LSTM models, where the dynamics module is pretrained on either Physion or the much larger Kinetics-700 dataset. This model class reasonably matches both Mental-Pong neural and OCP behavioral benchmarks we tested.
  • R3M+CTRNN/LSTM models, where the dynamics module is pretrained on either Physion or the much larger Kinetics-700 dataset. This model class best matches the Mental-Pong neural benchmark we tested.

We also include our best Physion-pretrained FitVid, SVG, and temporally-augmented C-SWM models, for additional points of comparison involving end-to-end pixel-wise and object-slot future predictors.

Once you have loaded the PyTorch model, you can extract features according to your pipeline of choice, as in the sketch below. Note that all of the models expect 7 context frames before running the forward simulation, so be sure to provide at least that many frames as input! For a standard example of extracting model features and running behavioral regression, see here. For examples of extracting model features per video (where the number of frames can differ per video, so videos must be processed one at a time), see here.
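For concreteness, here is a minimal sketch of running the forward simulation on a single video. The (batch, time, channels, height, width) input layout and the model's forward signature are assumptions; adapt them to the model class you actually loaded:

# Minimal sketch: extract features from one video.
# Assumes `model` is a loaded pretrained model that takes a tensor
# shaped (batch, time, channels, height, width); the real forward
# signature may differ per model class.
import torch

video = torch.randn(1, 12, 3, 224, 224)  # 12 frames: 7 context frames minimum, plus frames to simulate
model.eval()
with torch.no_grad():
    features = model(video)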

Training Code

Download your video pretraining dataset of choice and then run under mpmodels/model_training/:

python runner.py --config=[]

Specify the gpu_id in the config. The configs and hyperparameters we used are specified in the mpmodels/model_training/configs directory. Model architectures are implemented in the mpmodels/models/ directory.

For example, to train our VC-1+CTRNN model on the Physion dataset, you can run this command:

CUDA_VISIBLE_DEVICES=0 python runner.py --config=configs/pretrained_frozen_encoder/pfVC1_CTRNN/physion.json

Note that you will have to modify the save_prefix key in the JSON file to point to the directory where you want to save your checkpoints, as well as the train_root_path and val_root_path keys to point to where the pretraining dataset is stored.
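For example, you can edit those keys programmatically (the paths below are placeholders for your own directories):

# Point the config's save and dataset paths at your own directories.
# The three keys come from the repo's configs; the paths are placeholders.
import json

cfg_path = "configs/pretrained_frozen_encoder/pfVC1_CTRNN/physion.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["save_prefix"] = "/path/to/your/checkpoints/"
cfg["train_root_path"] = "/path/to/physion/train/"
cfg["val_root_path"] = "/path/to/physion/val/"

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)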

Cite

If you use this codebase for your research, please consider citing our paper:

@inproceedings{nayebi2023neural,
  title={Neural Foundations of Mental Simulation: Future Prediction of Latent Representations on Dynamic Scenes},
  author={Nayebi, Aran and Rajalingham, Rishi and Jazayeri, Mehrdad and Yang, Guangyu Robert},
  booktitle={The 37th Conference on Neural Information Processing Systems (NeurIPS 2023)},
  url={https://arxiv.org/abs/2305.11772},
  year={2023}
}

Contact

If you have any questions or encounter issues, either submit a GitHub issue here or email [email protected].
