Comments (7)
I think it would be good if num_samples behaved the same way in both implementations.
Yes, I think that's a great point, and the two implementations should have the same behavior. I'd tend towards having num_samples represent the total number of gradient calls the explainer makes, so I'll probably change the PathExplainerTorch implementation to more closely match the _sample_alphas method of PathExplainerTF.
from path_explain.
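For concreteness, here is a small standalone sketch (my own illustration, not the library's code -- the exact grids in path_explain may differ) of how an off-by-one in the interpolation grid changes which points along the path get sampled, while num_samples stays equal to the number of gradient calls in both cases:

```python
# Hypothetical illustration of two interpolation-grid conventions.
num_samples = 4

# With a "+1" shift: skips alpha = 0 (the baseline), includes alpha = 1 (the input)
with_plus_one = [(i + 1) / num_samples for i in range(num_samples)]

# Without the shift: includes alpha = 0, skips alpha = 1
without_plus_one = [i / num_samples for i in range(num_samples)]

print(with_plus_one)     # [0.25, 0.5, 0.75, 1.0]
print(without_plus_one)  # [0.0, 0.25, 0.5, 0.75]

# Either way, num_samples is exactly the number of gradient evaluations
assert len(with_plus_one) == len(without_plus_one) == num_samples
```

Both grids are valid Riemann approximations of the path integral; the point of the discussion is simply that the two explainers should pick the same convention.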
Great -- apparently there's some issue with torch v1.6 and transformers v4.8. Now that I've upgraded to torch v1.9, the gradients appear to be of the correct magnitude and are consistent with what I've seen in TensorFlow. (All our NLP experiments were initially done in TF, which is why we never caught this.)
I think you're correct that for the sake of consistency, the "+1" should be removed from the PyTorch explainer so that it instead looks like this:
scaled_inputs = [reference_tensor + (float(i) / num_samples) * (input_expand - reference_tensor)
                 for i in range(0, num_samples)]
Again, feel free to make a pull request if you'd like, and nice catch!
Would you be willing to send me a snippet of the code for the HuggingFace model you've been looking at in TF and PyTorch? I was planning to do some maintenance on the explainers this week, and that seems like a very good test for ensuring consistent behavior between the two explainers.
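In the meantime, a model-free way to check that two grid conventions agree in the limit is to compare left- and right-endpoint Riemann estimates of integrated gradients on a toy function. This is a hypothetical sketch, not part of path_explain; `ig` and the toy gradient are names I made up for illustration:

```python
import numpy as np

def ig(f_grad, x, baseline, alphas):
    # Generic integrated-gradients estimate for a given alpha grid:
    # average the gradient along the straight-line path, scale by (x - baseline)
    grads = np.stack([f_grad(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

grad = lambda x: np.cos(x)          # gradient of the toy function f(x) = sum(sin(x))
x = np.array([0.3, 1.2, -0.7])
b = np.zeros_like(x)

n = 2000
left = ig(grad, x, b, np.arange(n) / n)           # excludes alpha = 1
right = ig(grad, x, b, (np.arange(n) + 1) / n)    # excludes alpha = 0

# With enough samples, the two conventions converge to the same attributions
print(np.abs(left - right).max())
```

For small num_samples the two conventions can differ noticeably, which is presumably why the NLP examples were sensitive to the off-by-one.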
Yes, no problem at all -- I'll quickly set up a code snippet.
I've been testing your code on a causal language model, DistilGPT-2, on a simple input sentence ("It is raining cats and dogs"). The attributions computed here are for "dogs" based on the prefix "It is raining cats and".
Torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from path_explain import PathExplainerTorch, text_plot
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
def embedding_model(batch_ids):
    # Index the embedding matrix directly so gradients flow to the embeddings
    batch_embedding = model.transformer.wte.weight[batch_ids]
    return batch_embedding

def prediction_model(batch_embedding):
    # Log-probabilities over the vocabulary at the final position
    logits = model(inputs_embeds=batch_embedding).logits[:, -1]
    probs = torch.log_softmax(logits, dim=-1)
    return probs

sen = "It is raining cats and"
batch_ids = tokenizer([sen])['input_ids']
batch_embedding = embedding_model(batch_ids)
baseline_embedding = torch.zeros_like(batch_embedding)
output_idx = torch.tensor([tokenizer.convert_tokens_to_ids("Ġdogs")])
explainer = PathExplainerTorch(prediction_model)
attributions = explainer.attributions(
    batch_embedding,
    baseline_embedding,
    output_indices=output_idx,
    num_samples=9,
    use_expectation=False,
)
sum_attr = attributions.squeeze().detach().numpy().sum(-1)  # Sum over embedding dim
text_plot(
    sen.split(),
    sum_attr,
)
TensorFlow
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy.special import log_softmax
from transformers import TFAutoModelForCausalLM, AutoTokenizer
from path_explain import EmbeddingExplainerTF, text_plot
tf_model = TFAutoModelForCausalLM.from_pretrained("distilgpt2")
tf_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
def embedding_model(batch_ids):
    batch_embedding = tf_model.transformer.wte(batch_ids)
    return batch_embedding

def prediction_model(batch_embedding):
    # The explainer may pass numpy arrays; convert them back to tensors
    if isinstance(batch_embedding, np.ndarray):
        batch_embedding = tf.convert_to_tensor(batch_embedding)
    logits = tf_model(None, inputs_embeds=batch_embedding).logits[:, -1]
    probs = tf.nn.log_softmax(logits, axis=-1)
    return probs

sen = "It is raining cats and"
batch_ids = tf_tokenizer([sen])['input_ids']
batch_embedding = embedding_model(batch_ids)
baseline_embedding = np.zeros_like(batch_embedding)
output_idx = tf_tokenizer.convert_tokens_to_ids("Ġdogs")
explainer = EmbeddingExplainerTF(prediction_model)
attributions = explainer.attributions(
    batch_embedding,
    baseline_embedding,
    output_indices=output_idx,
    num_samples=10,
    batch_size=1,
    use_expectation=False,
)
text_plot(
    sen.split(),
    attributions[0],
)
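One model-independent sanity check that applies to either implementation is the completeness axiom of integrated gradients: over a straight-line path, the attributions should sum to f(input) - f(baseline). Here is a toy numpy sketch of that check (my own illustration, not part of path_explain), using a simple quadratic in place of the model's logit:

```python
import numpy as np

def f(x):
    # Toy scalar function standing in for the model's output logit
    return (x ** 2).sum()

def grad_f(x):
    return 2 * x

x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)
num_samples = 300

# Left-endpoint Riemann approximation of integrated gradients
alphas = np.arange(num_samples) / num_samples
grads = np.stack([grad_f(baseline + a * (x - baseline)) for a in alphas])
attributions = (x - baseline) * grads.mean(axis=0)

# Completeness: the attribution sum should approach f(x) - f(baseline)
print(attributions.sum(), f(x) - f(baseline))
```

Running the same check with the actual DistilGPT-2 log-probability (input vs. zero baseline) in both frameworks should catch any remaining grid or gradient-magnitude discrepancy.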
What versions of transformers and PyTorch are you using? I'm actually getting a weird issue where the gradients with respect to the embedding are really unreasonable in PyTorch, while looking totally correct in TF.
Ah, that's odd -- I'm currently on torch v1.8.0, tensorflow v2.5.0, and transformers v4.8.2.