alan-cooney / circuitsvis

Mechanistic Interpretability Visualizations using React

Home Page: https://alan-cooney.github.io/CircuitsVis/
License: MIT License
Suggestion: Have render raise an error if the user passes an argument that is not a prop of the TypeScript render function

This has bitten me a few times, when I misspell something or forget the name of an argument to the JavaScript function. It is very hard to debug, especially from within VSCode: the JavaScript part works, but the interface with Python is broken.
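A minimal sketch of what such a check could look like (`validate_props` and `known_props` are my assumptions, not the existing circuitsvis API; the list of valid props would need to be generated from the TypeScript prop types):

```python
# Hypothetical validation helper (not part of circuitsvis today): check each
# kwarg passed to render against the props declared by the React component.
from typing import Dict, Iterable


def validate_props(react_element_name: str, kwargs: Dict, known_props: Iterable[str]) -> None:
    # Any kwarg that is not a declared prop is almost certainly a typo
    unknown = set(kwargs) - set(known_props)
    if unknown:
        raise TypeError(
            f"{react_element_name} received unknown props: {sorted(unknown)}. "
            f"Expected a subset of: {sorted(known_props)}"
        )
```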
In the latest version, as I understand it, I need to do `import circuitsvis.attention` to get the attention pattern visualizer, and similarly for each other visualization. Previously, I could just do `import circuitsvis` to get all of them. Is this intentional? It feels like a worse user experience IMO.
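For concreteness, the two behaviours side by side (as I understand them):

```python
# Now: each visualization module must be imported explicitly before use
import circuitsvis.attention  # required before circuitsvis.attention.attention_patterns

# Before (as I remember it): this alone exposed every visualization,
# e.g. circuitsvis.attention, circuitsvis.tokens, ...
import circuitsvis
```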
Example use case: we have 3 tokens at the end of a prompt, and we want to see the attention probabilities from those tokens back to all other tokens in the sequence. This could be done via something like:

```python
cv.attention.attention_patterns(
    attention=attention,
    src_tokens=tokens,
    dest_tokens=tokens[-3:],
)
```
Not sure how difficult this would be to implement.
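In the meantime, a minimal workaround sketch with matplotlib (my own stopgap, not a circuitsvis feature), assuming `attention` is a torch tensor of shape [heads, dest_pos, src_pos] and `tokens` is a list of strings:

```python
import matplotlib.pyplot as plt

head = 0  # head to inspect
# Slice to the last 3 destination (query) positions, keep all source positions
pattern = attention[head, -3:, :].detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(10, 2))
ax.imshow(pattern, aspect="auto", cmap="Blues")
ax.set_yticks(range(3))
ax.set_yticklabels(tokens[-3:])
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=90, fontsize=8)
ax.set_xlabel("source token")
ax.set_ylabel("destination token")
plt.show()
```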
When building the dev container locally (Mac M1), I get:

```
RuntimeError

Unable to find installation candidates for nvidia-cuda-nvrtc-cu11 (11.7.99)

at ~/.local/share/pypoetry/venv/lib/python3.10/site-packages/poetry/installation/chooser.py:103 in choose_for
     99│
    100│                 links.append(link)
    101│
    102│         if not links:
  → 103│             raise RuntimeError(f"Unable to find installation candidates for {package}")
    104│
    105│         # Get the best link
    106│         chosen = max(links, key=lambda link: self._sort_key(package, link))
    107│

  • Installing nvidia-cuda-runtime-cu11 (11.7.99): Failed
```
I'm guessing the build works on the GitHub Codespaces servers with GPUs, but it would be nice not to require a GPU to build.
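The `poetry export --without-hashes --format=requirements.txt > requirements.txt && pip install -r requirements.txt` workaround from the psutil report below may also sidestep this; my guess at the root cause is that the CUDA-specific nvidia-* wheels only ship for x86_64 Linux, so there is nothing for poetry to install on an arm64 container.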
Currently, attention plots always mask the upper triangular region.
This behavior makes sense for models using causal attention, which seems to be most models. But I've encountered a few models that use bidirectional attention (for example, the toy model here), and I think it would be helpful to be able to visualize these attention patterns using the same tooling.
A simple solution is to add an optional boolean flag to the `attention_pattern` and `attention_heads` functions that toggles whether or not to mask the upper triangular region. It could be called `mask_upper_tri` and would of course default to `true`.

I can implement this quickly if other folks think it'd be useful.
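For illustration, usage might look like this (`mask_upper_tri` is the proposed name from this issue, not an existing parameter):

```python
# Hypothetical call if the proposed flag landed; mask_upper_tri does not
# exist in circuitsvis today.
cv.attention.attention_heads(
    tokens=tokens,
    attention=attention,
    mask_upper_tri=False,  # show the full bidirectional pattern
)
```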
When developing a new feature, there's an annoying amount of boilerplate I need to write: creating mock data in a JSON format, creating a stories.tsx file, writing the Props and function definition, and writing the Python function. I think this can mostly be automated, which would lower the friction of adding features, e.g. by adding a `generate_stub_files` function to `utils`. Here's an attempt:
```python
import torch


def snake_to_camel(s: str) -> str:
    # Split the string into a list of words
    words = s.split('_')
    # Capitalize the first letter of each word and join them
    camel = ''.join([word.capitalize() for word in words])
    # Lowercase the first letter
    return camel[0].lower() + camel[1:]


def snake_to_pascal(s: str) -> str:
    # Split the string into a list of words
    words = s.split('_')
    # Capitalize the first letter of each word and join them
    pascal = ''.join([word.capitalize() for word in words])
    return pascal


# %%
# Copy and paste into the mock file. The variables below (prompt,
# top_log_probs, top_tokens, etc.) come from a model run elsewhere.
mock_data = {
    "prompt": prompt,
    "top_k_log_probs": top_log_probs.tolist(),
    "top_k_tokens": top_tokens,
    "correct_tokens": correct_tokens,
    "correct_token_rank": correct_token_rank,
    "correct_token_log_prob": correct_token_log_prob.tolist(),
}
vis_name = "LogProbVis"
mock_types = {
    "prompt": "string[]",
    "top_k_log_probs": "number[][]",
    "top_k_tokens": "string[][]",
    "correct_tokens": "string[]",
    "correct_token_rank": "number[]",
    "correct_token_log_prob": "number[]",
}
print(mock_data)

# %%
for name in mock_data:
    data = mock_data[name]
    typ = mock_types[name]
    if isinstance(data, torch.Tensor):
        data = data.tolist()
    print(f"export const {snake_to_camel('mock_' + name)}: {typ} = {data};")
    print()

# %%
newline = "\n"
template = f"""import {{ ComponentStory, ComponentMeta }} from "@storybook/react";
import React from "react";
import {{ {", ".join(map(lambda name: snake_to_camel("mock_" + name), mock_data.keys()))} }} from "./mocks/{vis_name[0].lower() + vis_name[1:]}";
import {{ {vis_name} }} from "./{vis_name}";

export default {{
  component: {vis_name}
}} as ComponentMeta<typeof {vis_name}>;

const Template: ComponentStory<typeof {vis_name}> = (args) => (
  <{vis_name} {{...args}} />
);

export const SmallModelExample = Template.bind({{}});
SmallModelExample.args = {{
  {f",{newline}  ".join([f"{snake_to_camel(name)}: {snake_to_camel('mock_' + name)}" for name in mock_data])}
}};
"""
print(template)

# %%
func_defn = f"""
export function {vis_name}({{
  {f",{newline}  ".join(map(snake_to_camel, mock_data.keys()))}
}}: {vis_name}Props) {{
"""

interface = f"""
export interface {vis_name}Props {{
  {''';
  /**
   */
  '''.join([f"{snake_to_camel(name)}: {mock_types[name]}" for name in mock_types])}
}}
"""

print(func_defn)
print()
print(interface)
```
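Running the cells in order prints the mock data constants, the .stories.tsx file, and the component's function signature plus Props interface, each ready to paste into the corresponding new file.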
When trying to build the dev container through Docker Desktop on macOS with M1, I get:

```
  • Installing ipython (7.34.0)
  • Installing jupyter-server (1.23.0)
  • Installing psutil (5.9.4): Failed

  RuntimeError

  Unable to find installation candidates for psutil (5.9.4)

  at /usr/local/lib/python3.10/site-packages/poetry/installation/chooser.py:103 in choose_for
       99│
      100│                 links.append(link)
      101│
      102│         if not links:
    → 103│             raise RuntimeError(f"Unable to find installation candidates for {package}")
      104│
      105│         # Get the best link
      106│         chosen = max(links, key=lambda link: self._sort_key(package, link))
      107│

[18769 ms] postCreateCommand failed with exit code 1. Skipping any further user-provided commands.
Done. Press any key to close the terminal.
```
Not sure of the cause, since psutil installs fine with pip. A temporary fix is to export to a requirements.txt and use pip :):

```bash
poetry export --without-hashes --format=requirements.txt > requirements.txt && pip install -r requirements.txt
```
A logit visualizer, which shows text coloured by the prob/log-prob predicted for that token; hovering over a token would show the top 5 logits/probs/log-probs. (I was planning to try hacking this together with your library this evening, hence my questions.)
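As a rough starting point for the colouring half (my sketch; it assumes `str_tokens` is a list of strings and `per_token_log_probs` a [seq] tensor, and it doesn't do the hover part):

```python
import circuitsvis as cv

# Colour each token by the log-prob the model assigned to it. The hover
# behaviour showing the top-5 candidates would still need a new component.
cv.tokens.colored_tokens(
    tokens=str_tokens,
    values=per_token_log_probs.tolist(),
)
```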
When I define a CircuitsVis component, it can take several seconds, because it generates `local_src` and in particular runs a `bundle_source` command. I often just want the `cdn_src` when I am not actively developing CircuitsVis (but using it elsewhere), but the components don't let me explicitly turn the local build off.

I may fix this myself at some point, but assigning to @alan-cooney in case there's a principled way of doing this. IIRC you used to have it another way and changed your mind?
When I run any CircuitsVis render (e.g. circuitsvis.examples.hello("Neel")) I get the following error:
```
---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
/tmp/ipykernel_1001422/1929088876.py in <module>
      1 import circuitsvis.examples
----> 2 circuitsvis.examples.hello("Help")

~/CircuitsVis/python/circuitsvis/examples.py in hello(name)
     16     return render(
     17         "Hello",
---> 18         name=name,
     19     )

~/CircuitsVis/python/circuitsvis/utils/render.py in render(react_element_name, **kwargs)
    170         Html: HTML for the visualization
    171     """
--> 172     local_src = render_local(react_element_name, **kwargs)
    173     cdn_src = render_cdn(react_element_name, **kwargs)
    174     return RenderedHTML(local_src, cdn_src)

~/CircuitsVis/python/circuitsvis/utils/render.py in render_local(react_element_name, **kwargs)
    103     if REACT_DIR.exists():
    104         install_if_necessary()
--> 105         bundle_source()
    106
    107     # Load the JS

~/CircuitsVis/python/circuitsvis/utils/render.py in bundle_source(dev_mode)
     81         capture_output=True,
     82         text=True,
---> 83         check=True
     84     )
     85

/opt/conda/lib/python3.7/subprocess.py in run(input, capture_output, timeout, check, *popenargs, **kwargs)
    510     if check and retcode:
    511         raise CalledProcessError(retcode, process.args,
--> 512                                  output=stdout, stderr=stderr)
    513     return CompletedProcess(process.args, retcode, stdout, stderr)
    514

CalledProcessError: Command '['yarn', 'buildBrowser', '--dev']' returned non-zero exit status 1.
```
When coding locally, `html.local_src` injects a very large amount of code each time I render an HTML object. This mostly comes from injecting the source code of underlying libraries (e.g. tensorflow.js) rather than loading them from a CDN. This makes my notebooks much larger, and makes the visualizations take several seconds to load. I'm always coding online, so I'd like the option to load library source code from the CDN while still inlining the local code that I am actively working on.
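A stopgap I use today (a sketch, assuming the RenderedHTML object returned by each component keeps exposing `cdn_src` alongside `local_src`): display only the CDN-backed variant, so the notebook doesn't embed the bundled source.

```python
from IPython.display import HTML, display

import circuitsvis as cv

html = cv.examples.hello("Neel")
# html.local_src inlines the full bundle; html.cdn_src references the
# published package instead, keeping the notebook small.
display(HTML(html.cdn_src))
```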
Google Colaboratory currently has torch 2.0.0+cu118 installed by default. `!pip install circuitsvis` triggers a downgrade of torch to 1.13.1, which in turn downgrades nvidia_cuda and many other packages and slows setup considerably.

Is there much involved in supporting the later torch version, possibly with a constraint dependent on the Python version, as with numpy?
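Until the constraint is relaxed, one hedged stopgap on Colab (my suggestion, untested against circuitsvis's full dependency set) is `!pip install circuitsvis --no-deps`, which skips pip's resolver entirely; it only works because Colab already ships torch and numpy.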
Hey, I am trying to understand the attention pattern visualization. Is the destination axis the input text passed to the encoder, and the source axis for the decoder?
Version 1.38.0:

`str(cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern))` produces HTML containing https://unpkg.com/circuitsvis@1.38.0/dist/cdn/esm.js. This URL returns HTTP 500 Internal Server Error as of today (I think it was working yesterday).

Variants of that URL, e.g. https://unpkg.com/circuitsvis, or the same dist/cdn/esm.js path at other versions, also return an Internal Server Error.

Other unpkg links work, e.g. React's production UMD bundle (https://unpkg.com/react@<version>/umd/react.production.min.js).
(I love this library by the way!)
`convert_props({"a": np.float32(1.)})` gives the error `TypeError: Object of type float32 is not JSON serializable`.

This is an issue if you take e.g. `max_value=array.max()` for some numpy array, as this returns a `np.float32` object, not a `float`. This can likely be fixed with a quick hack to cast such types to a float. I may add this myself at some point.
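A minimal sketch of the cast (where exactly to hook it into convert_props is my assumption):

```python
import numpy as np


def to_json_safe(value):
    # np.generic covers all numpy scalar types (np.float32, np.int64, ...);
    # .item() converts them to the matching Python builtin, which the json
    # module can serialize.
    if isinstance(value, np.generic):
        return value.item()
    return value
```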
I'm looking at https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp / https://colab.research.google.com/drive/1w9zCWpE7xd1sDuMT_rsjARfFozeWiKF4, in particular the "Visualising Attention Heads" section, with the code:

```python
print(type(gpt2_cache))
attention_pattern = gpt2_cache["pattern", 0, "attn"]
print(attention_pattern.shape)
gpt2_str_tokens = gpt2_small.to_str_tokens(gpt2_text)

print("Layer 0 Head Attention Patterns:")
display(cv.attention.attention_patterns(
    tokens=gpt2_str_tokens,
    attention=attention_pattern,
    attention_head_names=[f"L0H{i}" for i in range(12)],
))
```
And it seems like once I click on a head and/or token to lock the focus, there's no way to unlock the focus and get back the averaged value. There should be a way to do this, and the visualization should signpost this.
(Also, "Tokens (click to focus)" should probably be "Tokens (hover to focus, click to lock)" much like "Head selector (hover to focus, click to lock)")
Weird shrinkage of `cv.attention.attention_heads` plots.