huthlab / encoding-model-scaling-laws
Repository for the 2023 NeurIPS paper "Scaling laws for language encoding models in fMRI"
Sorry to bother you again. I have a few more questions about reproduction.
Thank you in advance for your response!
Originally posted by @dyhan316 in #1 (comment)
I am trying to replicate the Figure 1 encoding performance plot in the paper but am having difficulty.
I followed the tutorial Jupyter notebook (using the 33rd layer of the OPT-30B model) and tried to reproduce the results for subject S3. I was able to reproduce the Figure 2 results (voxel-wise r values). However, I was not able to reproduce the "Encoding Performance (Avg r^2)" values of Figure 1: I got values around 0.02, not the 0.03 that Figure 1 reports.
Below is what I got by taking the voxel-wise r values (corrs_unnorm), converting them to signed r^2 (|r|*r), and averaging for each trial. The values differ from those in Figure 1.
(Below is Figure 1, for reference)
Could you please explain how I can reproduce the results in the paper? My current assumptions are that:
Thank you in advance!
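For concreteness, my computation of the average signed r^2 can be sketched like this (the values in corrs_unnorm are made up for illustration; in practice it is the per-voxel correlation array produced by the tutorial notebook):

```python
import numpy as np

# Hypothetical per-voxel correlations standing in for the notebook's
# `corrs_unnorm` output.
corrs_unnorm = np.array([0.2, -0.1, 0.15, 0.05])

# Signed r^2: |r| * r keeps the sign, so anti-correlated voxels
# subtract from the average instead of inflating it.
signed_r2 = np.abs(corrs_unnorm) * corrs_unnorm

avg_r2 = signed_r2.mean()
print(avg_r2)
```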
Hello, I was just wondering:
How were the word times decided for each token? It appears that when you used GPT-1 (in the Nature Neuroscience paper), each word was a single token, so each word's time could be mapped directly to the timing of its token (later to be resampled with Lanczos resampling).
However, this paper uses models with subword tokenization (i.e. one word != one token). So I was wondering: how did you assign each token a timepoint?
Thank you in advance for your answer :)
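For what it's worth, one common workaround (an assumption on my part, not necessarily what this repo does) is to give every subword token the timestamp of the word it came from, so that Lanczos resampling sees all of a word's subtokens at the same timepoint:

```python
# Sketch: each subword token inherits the time of its parent word.
# `fake_tokenize` and the word/time pairs are made up for illustration;
# a real tokenizer (e.g. the OPT BPE tokenizer) would replace it.
def fake_tokenize(word):
    # pretend every 4 characters form one subtoken
    return [word[i:i + 4] for i in range(0, len(word), 4)]

words = [("where", 0.0), ("theres", 0.5), ("smoke", 1.0)]

token_times = []
for word, t in words:
    for tok in fake_tokenize(word):
        token_times.append((tok, t))  # subtoken inherits the word's time

print(token_times)
```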
Actually, could you please provide the code you used for LLAMA?
Hello again! The response data for 'wheretheressmoke' of subject S01 seems to be averaged incorrectly.
In other words, when I calculate tensession_resp.mean(axis=0) / avg_resp_dict['wheretheressmoke'],
the values are not 1 for sub1, while they are 1 for sub2 and sub3.
#Sub1
array([[ -0.5260241 , 0.43809538, 0.72644223, ..., -0.12333434,
-0.06862524, -0.17878285],
[ 0.22232635, -66.78132633, -0.44553775, ..., -0.74112664,
-0.9559144 , 0.36624247],
[ 0.75621392, 0.79440582, 0.69228175, ..., -1.66421863,
-1.03083738, -0.93969021],
...,
[ 1.98471561, 0.2910771 , 0.55035473, ..., -0.44772529,
1.72803771, -0.8670477 ],
[ -0.14013489, -0.90228549, -0.38435461, ..., 0.51850835,
-0.40164749, 0.51553115],
[ 0.16714859, 0.11103255, -0.30769317, ..., 0.51741804,
1.05856689, 0.30922208]])
#Sub2
array([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
#Sub3
array([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]])
I found that this may be because sub1's wheretheressmoke data was normalized after the sessions were averaged together, as you can see in the per-story (shape, mean, std) values below:
#Sub1
itsabox ((355, 81126), 4.7093635045150747e-17, 0.9999999999999991)
odetostepfather ((404, 81126), -4.183271636501786e-18, 1.0)
inamoment ((205, 81126), -3.715829144698577e-17, 1.0000000000000004)
wheretheressmoke ((291, 81126), -3.1167046066065954e-17, 1.0)
onapproachtopluto ((271, 81126), -5.8329725420055375e-18, 0.5123625396215089)
fromboyhoodtofatherhood ((348, 81126), -4.8967107880959505e-18, 0.43156861943485947)
#Sub2
itsabox ((355, 94251), 3.499718323405976e-17, 1.0000000000000004)
odetostepfather ((404, 94251), -3.9085500487734813e-17, 1.0000000000000016)
inamoment ((205, 94251), -8.414075334487538e-18, 0.999999999999999)
wheretheressmoke ((291, 94251), -1.8649678783029527e-17, 0.3592392686995733)
onapproachtopluto ((271, 94251), 2.079494437168422e-17, 0.4883584127554785)
fromboyhoodtofatherhood ((348, 94251), 4.055373299929339e-17, 0.4661844447316095)
#Sub3
itsabox ((355, 95556), 1.9632392991683566e-17, 0.9999999999999999)
odetostepfather ((404, 95556), 3.3595442989632425e-17, 0.9999999999999987)
inamoment ((205, 95556), 2.608143194392676e-17, 1.0000000000000013)
wheretheressmoke ((291, 95556), 3.090872077370644e-18, 0.36932671494961317)
onapproachtopluto ((271, 95556), -2.6657210209366712e-17, 0.4865749028668901)
fromboyhoodtofatherhood ((348, 95556), -3.145461764665795e-17, 0.4530105777188318)
The code used to generate the results above is attached below for reference:
# Checking the averaged responses against the per-session data.
import os

import joblib
import numpy as np

story2see = ['itsabox', 'odetostepfather', 'inamoment', 'wheretheressmoke',
             'onapproachtopluto', 'fromboyhoodtofatherhood']

os.chdir('/scratch/x2767a03/BRAIN_DECODING/brain_scaling_law/sample_data/wheretheressmoke_preds')

sub = "S01"  # S02, S03: work as expected
avg_resp_dict = joblib.load(f'UT{sub}_responses.jbl')
tensession_resp = joblib.load(f'tensessions_wheretheressmoke_{sub}.jbl')

def get_stats(arr):
    arr = np.nan_to_num(arr)
    return arr.shape, np.mean(arr), np.std(arr)

# Per-story stats of the stored averaged responses: (shape, mean, std).
for story in story2see:
    print(story, get_stats(avg_resp_dict[story]))

# Per-trial stats of the ten individual sessions.
for trial in range(10):
    print(f'trial {trial}', get_stats(tensession_resp[trial]))

# Ratio of the session average to the stored averaged response;
# this should be all ones if the stored average is the plain mean.
print(tensession_resp.mean(axis=0) / avg_resp_dict['wheretheressmoke'])

from sklearn.metrics import r2_score
print(r2_score(tensession_resp.mean(axis=0), avg_resp_dict['wheretheressmoke']),
      avg_resp_dict['wheretheressmoke'].std(axis=0))
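The std pattern above is consistent with z-scoring happening after (rather than before) averaging: averaging N independently z-scored repeats shrinks the std to roughly 1/sqrt(N) unless the result is re-normalized. A minimal demonstration on synthetic data (not the actual responses):

```python
import numpy as np

rng = np.random.default_rng(0)

# Ten synthetic "sessions" (trials x timepoints x voxels),
# each z-scored per voxel within each session.
sessions = rng.standard_normal((10, 291, 100))
sessions = (sessions - sessions.mean(axis=1, keepdims=True)) / sessions.std(axis=1, keepdims=True)

# Averaging the ten unit-variance sessions shrinks the std toward
# 1/sqrt(10), like sub2/sub3's wheretheressmoke entry.
avg = sessions.mean(axis=0)
print(avg.std())

# Re-z-scoring after averaging restores std 1, like sub1's entry.
avg_z = (avg - avg.mean(axis=0)) / avg.std(axis=0)
print(avg_z.std())
```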