
sparse-probing

Code repository for Finding Neurons in a Haystack: Case Studies with Sparse Probing

Pardon our mess. The core of sparse probing can be implemented very easily with just sklearn applied to a dataset of activations collected with raw PyTorch hooks or TransformerLens. This repository is almost entirely experimental infrastructure and analysis specific to our datasets and compute setup (Slurm).
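To make that concrete, here is a minimal sketch of the core idea (not the repo's pipeline): given a cached activation matrix and binary feature labels, an L1-penalized logistic regression drives most neuron coefficients to zero, and the surviving nonzero-weight neurons are the probe's selected neurons. The synthetic data below, where one planted neuron encodes the feature, is purely illustrative.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# X holds cached activations (n_tokens, n_neurons); y holds binary
# feature labels. Here we fabricate data where neuron 7 encodes the feature.
rng = np.random.default_rng(0)
n, d = 2000, 64
X = rng.standard_normal((n, d))
y = (X[:, 7] > 0).astype(int)  # pretend neuron 7 encodes the feature

# An L1 penalty zeroes out most coefficients, yielding a sparse probe.
probe = LogisticRegression(penalty="l1", solver="liblinear", C=0.1)
probe.fit(X, y)

# Indices of neurons with nonzero weight -- the "found" neurons.
selected = np.flatnonzero(probe.coef_[0])
```

With strong regularization (small `C`), `selected` collapses to the few neurons that actually carry the feature.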

See this repository for a minimal replication of finding context neurons.

Organization

We expect most people to simply be interested in a large list of relevant neurons, available as CSVs within interpretable_neurons/. Note these are for the Pythia V0 models, which have since been updated on HuggingFace.

Our top-level scripts for saving activations and running probing experiments are get_activations.py and probing_experiment.py. All command-line argument configurations can be viewed in the experiments/ directory, which contains all of the Slurm scripts we used to run our experiments.

probing_datasets/ contains the modules required to make and prepare all of our feature datasets. We recommend simply downloading them from Dropbox.

Analysis and plotting code is distributed within individual notebooks and analysis/.

Instructions for reproducing

Note that our full experiments generate well over 1 TB of data and require substantial GPU and CPU time.

Getting started

Create virtual environment and install required packages

git clone https://github.com/wesg52/sparse-probing-paper.git
cd sparse-probing-paper
python -m venv sparprob
source sparprob/bin/activate
pip install -r requirements.txt

Acquire a Gurobi license (free for academics). Make sure you are on campus WiFi (you may also need to separately install grbgetkey).

Environment variables

To enable running our code in many different environments, we use environment variables to specify the paths for all data input and output. For example:

export RESULTS_DIR=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/results
export FEATURE_DATASET_DIR=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/feature_datasets
export TRANSFORMERS_CACHE=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/downloads
export HF_DATASETS_CACHE=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/downloads
export HF_HOME=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/downloads
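Before launching experiments, it can help to fail fast when one of these variables is unset. A small convenience sketch (this helper is not part of the repo):

```python
import os

def check_probing_env(env=os.environ):
    """Return the list of required path variables that are not set."""
    required = ["RESULTS_DIR", "FEATURE_DATASET_DIR",
                "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_HOME"]
    return [v for v in required if v not in env]
```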

Cite us

If you found our work helpful, please cite our paper:

@article{gurnee2023finding,
  title={Finding Neurons in a Haystack: Case Studies with Sparse Probing},
  author={Gurnee, Wes and Nanda, Neel and Pauly, Matthew and Harvey, Katherine and Troitskii, Dmitrii and Bertsimas, Dimitris},
  journal={arXiv preprint arXiv:2305.01610},
  year={2023}
}

sparse-probing-paper's People

Contributors

wesg52

sparse-probing-paper's Issues

Why are all OSP results not available?

Why can I not run OSP with --osp_upto_k values greater than 16? Whenever I set the value to a high number such as 200, I get the following error:

Traceback (most recent call last):
  File "probe_adapters_exp.py", line 260, in <module>
    run_probe_on_layer(
  File "probe_adapters_exp.py", line 106, in run_probe_on_layer
    result = inner_loop_fn(
  File "/data/users/jalabi/adaptation/sparse-probing-paper/experiments/inner_loops.py", line 178, in optimal_sparse_probing
    model_stats, filtered_support, beta, bias = sparse_classification_oa(
  File "/data/users/jalabi/adaptation/sparse-probing-paper/experiments/probes.py", line 133, in sparse_classification_oa
    support_indices = sorted([i for i in range(len(s)) if s[i].X > 0.5])
  File "/data/users/jalabi/adaptation/sparse-probing-paper/experiments/probes.py", line 133, in <listcomp>
    support_indices = sorted([i for i in range(len(s)) if s[i].X > 0.5])
  File "src/gurobipy/var.pxi", line 125, in gurobipy.Var.__getattr__
  File "src/gurobipy/var.pxi", line 153, in gurobipy.Var.getAttr
  File "src/gurobipy/attrutil.pxi", line 100, in gurobipy.__getattr
AttributeError: Unable to retrieve attribute 'X'

@wesg52, do you by chance know what the issue is?
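For context on tracebacks like the one above: in gurobipy, a variable's `.X` attribute only exists once the model holds at least one feasible incumbent solution (reported via the model's `SolCount` attribute), so this `AttributeError` typically means the solver terminated, e.g. on a time limit, before finding any feasible point; larger k values make the underlying MIP much harder. A sketch of the usual guard pattern (the helper and its arguments are hypothetical stand-ins for the real gurobipy objects, not code from this repo):

```python
def extract_support(sol_count, xs):
    """Return indices of selected neurons, or [] when no incumbent exists.

    `sol_count` stands in for model.SolCount and `xs` for the .X values
    of the binary support variables; in real gurobipy code you would
    check model.SolCount before touching any var.X attribute.
    """
    if sol_count == 0:
        return []  # solver stopped (e.g. time limit) with no feasible point
    # Mirrors the repo's `s[i].X > 0.5` support extraction.
    return sorted(i for i, x in enumerate(xs) if x > 0.5)
```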

Replicating experiments on the LLaMA model

Hello, thank you for sharing your code. I want to replicate your experiments on the LLaMA model, mainly on the compound_words dataset. However, the data you provided is based on the Pythia models, so I need to rebuild the dataset for LLaMA. I noticed the file make_feature_datasets.py, but when I tried to run it to build the data, I found that the file referenced by "PILE_TEST_PATH = '/home/gridsan/groups/maia_mechint/datasets/pile-test.hf'" is missing. Can you please provide this file? Or, if I need to replicate your experiments on LLaMA, do you have any suggestions?
