
Comments (7)

jjanizek commented on May 23, 2024

I think it would be good if num_samples acts the same for both implementations.

Yes, I think that's a great point, and the two implementations should have the same behavior. I'd tend towards having num_samples represent the total number of gradient calls the explainer makes, so I'll probably change the PathExplainerTorch implementation to more closely match the _sample_alphas method of PathExplainerTF.
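
For illustration, a minimal sketch of the invariant being proposed (the helper name _sample_alphas comes from PathExplainerTF; the body below is my assumption, not the library's actual code):

import numpy as np

def sample_alphas(num_samples, use_expectation=False):
    # Sketch of the invariant: return exactly num_samples interpolation
    # constants, so num_samples always equals the number of gradient calls.
    if use_expectation:
        # Expected-gradients style: sample alphas uniformly at random.
        return np.random.uniform(low=0.0, high=1.0, size=num_samples)
    # Integrated-gradients style: evenly spaced alphas in [0, 1).
    return np.array([i / num_samples for i in range(num_samples)])

alphas = sample_alphas(9)
assert len(alphas) == 9  # one gradient evaluation per alpha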


jjanizek commented on May 23, 2024

Great -- apparently there's some issue between torch v1.6 and transformers v4.8. Now that I've upgraded to torch v1.9, the gradients appear to be of the correct magnitude and consistent with what I've seen in TensorFlow (all our NLP experiments were initially done in TF, which is why we never caught this).

I think you're correct that for the sake of consistency, the "+1" should be removed from the PyTorch explainer so that it instead looks like this:
scaled_inputs = [reference_tensor + (float(i) / num_samples) * (input_expand - reference_tensor)
                 for i in range(num_samples)]
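
To make the effect concrete, an illustrative before/after (assuming, as the diff suggests, that the old loop was range(0, num_samples + 1)):

num_samples = 4
old_alphas = [i / num_samples for i in range(num_samples + 1)]  # [0.0, 0.25, 0.5, 0.75, 1.0] -> 5 gradient calls
new_alphas = [i / num_samples for i in range(num_samples)]      # [0.0, 0.25, 0.5, 0.75] -> 4 gradient calls
assert len(new_alphas) == num_samples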

Again, feel free to make a pull request if you'd like, and nice catch!


jjanizek commented on May 23, 2024

Would you be willing to send me a snippet of the code for the HuggingFace model you've been looking at in TF and PyTorch? I was planning to do some maintenance on the explainers this week, and that seems like a very good test for ensuring consistency of behavior between the two explainers.


jumelet commented on May 23, 2024

Yes, no problem at all; I'll quickly set up a code snippet.


jumelet commented on May 23, 2024

I've been testing your code on a causal language model, DistilGPT-2, with a simple input sentence ("It is raining cats and dogs"). The attributions computed here are for "dogs", based on the prefix "It is raining cats and".

Torch

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from path_explain import PathExplainerTorch, text_plot

model = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

def embedding_model(batch_ids):
    # Look up the input embeddings directly from the embedding matrix.
    batch_embedding = model.transformer.wte.weight[torch.tensor(batch_ids)]
    return batch_embedding

def prediction_model(batch_embedding):
    # Feed the embeddings straight into the transformer and take the logits
    # at the final position (the next-token prediction).
    logits = model(inputs_embeds=batch_embedding).logits[:, -1]
    probs = torch.log_softmax(logits, dim=-1)

    return probs

sen = "It is raining cats and"

batch_ids = tokenizer([sen])['input_ids']
batch_embedding = embedding_model(batch_ids)
baseline_embedding = torch.zeros_like(batch_embedding)
output_idx = torch.tensor([tokenizer.convert_tokens_to_ids("Ġdogs")])

explainer = PathExplainerTorch(prediction_model)

attributions = explainer.attributions(
    batch_embedding,
    baseline_embedding,
    output_indices=output_idx,
    num_samples=9,   # 9 here vs. 10 in TF, to compensate for the "+1" discussed above
    use_expectation=False,
)

sum_attr = attributions.squeeze().detach().numpy().sum(-1)  # Sum over embedding dim

text_plot(
    sen.split(),
    sum_attr,
)

TensorFlow

import numpy as np
import tensorflow as tf

from transformers import TFAutoModelForCausalLM, AutoTokenizer

from path_explain import EmbeddingExplainerTF, text_plot

tf_model = TFAutoModelForCausalLM.from_pretrained("distilgpt2")
tf_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

def embedding_model(batch_ids):
    # Run the ids through the embedding layer itself.
    batch_embedding = tf_model.transformer.wte(batch_ids)
    return batch_embedding

def prediction_model(batch_embedding):
    if isinstance(batch_embedding, np.ndarray):
        batch_embedding = tf.convert_to_tensor(batch_embedding)
    # Pass None for input_ids so the model consumes the embeddings instead.
    logits = tf_model(None, inputs_embeds=batch_embedding).logits[:, -1]
    probs = tf.nn.log_softmax(logits, axis=-1)

    return probs

sen = "It is raining cats and"

batch_ids = tf_tokenizer([sen])['input_ids']
batch_embedding = embedding_model(batch_ids)

baseline_embedding = np.zeros_like(batch_embedding)

output_idx = tf_tokenizer.convert_tokens_to_ids("Ġdogs")

explainer = EmbeddingExplainerTF(prediction_model)
attributions = explainer.attributions(
    batch_embedding,
    baseline_embedding,
    output_indices=output_idx,
    num_samples=10,
    batch_size=1,
    use_expectation=False,
)

text_plot(
    sen.split(),
    attributions[0],
)
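
For what it's worth, a quick cross-check of the two runs could look like this (hypothetical; it assumes both snippets were executed in the same session, with sum_attr from the Torch snippet and attributions from the TF one):

import numpy as np

# Both arrays hold one attribution score per input token.
tf_attr = np.asarray(attributions[0])
print("max abs difference:", np.max(np.abs(sum_attr - tf_attr)))  # small once num_samples agrees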


jjanizek commented on May 23, 2024

What versions of transformers and PyTorch are you using? I'm actually getting a weird issue where the gradients with respect to the embedding look really unreasonable in PyTorch while looking totally correct in TF.


jumelet commented on May 23, 2024

Ah, that's odd. I'm currently on torch v1.8.0, tensorflow v2.5.0, and transformers v4.8.2.

