
Meaningfully debugging model mistakes with conceptual counterfactual explanations

What is this work about?

Understanding model mistakes is critical to many machine learning objectives, yet it is still an ad hoc process that involves manually examining the model's mistakes on many test samples and guessing the underlying reasons for those incorrect predictions. With CCE, we take a step towards systematically analyzing model mistakes in high-level, human-understandable terms, i.e., concepts!

Here is an overview of our framework, and you can find more in our ICML 2022 paper Meaningfully Debugging Model Mistakes using Conceptual Counterfactual Explanations.

(Figure: framework overview)

This is joint work with the wonderful Abubakar Abid and James Zou.

Here we aim to make the code as easy to use as possible, so that you can try conceptual counterfactual explanations without too much overhead!

How to generate conceptual counterfactuals?

Learning a concept bank

First, you need to define a concept bank for your problem. In our experiments, we derive the Broden concept bank, which consists of 170 concepts.

In examples/resnet18_bank.pkl we release the concept bank and concept statistics for a ResNet18. If you wish to train your own concept banks, you can use the script learn_concepts.py.

python learn_concepts.py --concept-data=/user/concept-folder/ --model-name resnet18 --C=0.001 --output-folder=banks/

For details, please refer to the learn_concepts.py script and/or our paper.
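
For intuition, here is a minimal sketch of what learning a single concept might look like: fit a linear SVM that separates backbone embeddings of images containing the concept from embeddings of images without it, and keep its weight vector together with margin statistics. The function below is illustrative, not the repository's API; pos_embs and neg_embs are embeddings you would compute yourself.

import numpy as np
from sklearn.svm import LinearSVC

def learn_concept_vector(pos_embs, neg_embs, C=0.001):
    # pos_embs / neg_embs: backbone embeddings (e.g. ResNet18 penultimate
    # layer) for images that do / do not contain the concept
    X = np.concatenate([pos_embs, neg_embs])
    y = np.concatenate([np.ones(len(pos_embs)), np.zeros(len(neg_embs))])
    svm = LinearSVC(C=C).fit(X, y)
    vector = svm.coef_.squeeze()  # the concept direction
    # signed distances to the decision boundary; their min/max are the
    # concept statistics used later to bound counterfactual perturbations
    margins = (X @ vector + svm.intercept_[0]) / np.linalg.norm(vector)
    return vector, margins.min(), margins.max()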

Generate counterfactual explanations

Given a concept bank and a model, you can run generate_counterfactuals.py to generate the explanations. For instance, the command

python3 generate_counterfactuals.py --concept-bank=examples/resnet18_bank.pkl --image-folder=examples/images/ --model-path="./examples/models/dog(snow).pth"

runs the script to generate counterfactual explanations for the images in the examples/images/ folder using the ResNet18 model. The output is saved to examples/{image_name}_explanation.png for each image. For the examples that we share, upon running the command above, the output image should look like the following:

(Figure: example explanation)
We took the example image from the LAION-5B dataset (you should check it out, it's very cool).
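
Under the hood, CCE searches for a sparse set of concept weights that, when added to the image's embedding along the concept bank directions, flips the model's prediction to the correct label; large positive (negative) weights then read as concepts to add (remove). The PyTorch sketch below is a simplification of the actual objective; see generate_counterfactuals.py and the paper for the full loss and the validity constraints on the weights.

import torch
import torch.nn.functional as F

def conceptual_counterfactual(embedding, label, concept_bank, classifier_head,
                              steps=100, lr=0.01, alpha=0.1):
    # embedding: (d,) backbone embedding of a misclassified image
    # concept_bank: (n_concepts, d) matrix of concept directions
    # classifier_head: the model's final linear layer
    w = torch.zeros(concept_bank.shape[0], requires_grad=True)
    optimizer = torch.optim.Adam([w], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        perturbed = embedding + w @ concept_bank  # move along concept directions
        logits = classifier_head(perturbed).unsqueeze(0)
        # push toward the correct label; the L1 term keeps the explanation sparse
        loss = F.cross_entropy(logits, torch.tensor([label])) + alpha * w.abs().sum()
        loss.backward()
        optimizer.step()
        # the real implementation also clamps w to a validity region derived
        # from the concept margin statistics (W_clamp_min / W_clamp_max)
    return w.detach()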

Metashift Experiments

First, please obtain the Metashift dataset using the instructions in the Metashift repository.

A few things to do to run Metashift experiments:

Please go to /metashift/constants.py and replace the pointers to the dataset. In particular, the METASHIFT_ROOT variable should point to the dataset root, so that the images are found under METASHIFT_ROOT/allImages/images/ and the subset metadata under METASHIFT_ROOT/full-candidate-subsets.pkl.
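
For example, assuming your Metashift copy lives under /data/metashift, the edited constants would look something like this (only METASHIFT_ROOT is a name taken from the repository; treat the rest as a sketch):

# /metashift/constants.py -- point this at your local copy
METASHIFT_ROOT = "/data/metashift"
# the code then expects:
#   {METASHIFT_ROOT}/allImages/images/           all Metashift images
#   {METASHIFT_ROOT}/full-candidate-subsets.pkl  subset metadata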

How do we train spuriously correlated models?

There are a few simple commands in /metashift/run_experiments.sh. For instance,

python3 train.py --dataset="bear-bird-cat-dog-elephant:dog(snow)"

will train a 5-class classification model where, during training, all of the dog images are sampled from the snow domain.

Do I have to train models from scratch?

For convenience, we also share checkpoints for some of the models we trained. You can find them in /examples/. For instance, /examples/models/dog(snow).pth is the model for the dog(snow) scenario.

How do we evaluate CCE?

Similarly, see /metashift/run_experiments.sh. You can either specify the model with --model-path or use the same output directory where you saved the confounded model. For instance, you can use

python3 evaluate_cce.py --dataset="bear-bird-cat-dog-elephant:dog(snow)" --concept-bank="../examples/resnet18_bank.pkl" --model-path="../examples/models/dog(snow).pth"

to evaluate the model that we provide with this repository. Or, if you have trained your own models, you can use

python3 evaluate_cce.py --dataset="bear-bird-cat-dog-elephant:dog(snow)" --concept-bank=/your/bank/path/ --out-dir=/your/output/directory/

to evaluate your model.

Contact

If you have any concerns or questions, please reach out to me at [email protected].

If you find this work useful, please consider citing our ICML 2022 paper:

@InProceedings{abid22a,
  title     = {Meaningfully debugging model mistakes using conceptual counterfactual explanations},
  author    = {Abid, Abubakar and Yuksekgonul, Mert and Zou, James},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  year      = {2022},
  volume    = {162},
  series    = {Proceedings of Machine Learning Research},
  publisher = {PMLR},
  url       = {https://proceedings.mlr.press/v162/abid22a.html}
}


Issues

[Bug] MetashiftManager loads wrong images

First of all, thank you for your work.

I am re-implementing your code for the spurious-correlation detection experiments, and I found that the wrong evaluation images are being used.

Load image list

from dataset import MetashiftManager

# parse the dataset spec: classes before ":", training domain in parentheses
dataset_name = "bear-bird-cat-dog-elephant:dog(snow)"
manager = MetashiftManager()
classes, train_domain = dataset_name.split(":")
classes = classes.split("-")
num_classes = len(classes)
# e.g. shift_class = "dog", spurious_concept = "snow"
shift_class = train_domain.split("(")[0]
spurious_concept = train_domain.split("(")[1][:-1].lower()
print(f"Shift Class: {shift_class}, Spurious Concept: {spurious_concept}")

Result:

Shift Class: dog, Spurious Concept: snow

Get image list for all classes

class_images = manager.get_class_ims(classes)
len(class_images["dog"])

Result:

2580

Check images in the "dog" class

from PIL import Image

# inspect two of the returned "dog" images by index
Image.open(class_images["dog"][2077])

(embedded image)

Image.open(class_images["dog"][1616])

(embedded image)

For classification purposes, there is no dog in the crop taken from this image.

Many images in the "dog" class are unrelated to dogs.

Please let me know the reason. Is your experiment affected?

I found that the wrong images contain hot dogs...
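
A plausible cause (an assumption on my side, not confirmed above) is substring matching when collecting class images: a lookup for "dog" would also match Metashift subsets named "hotdog". A sketch of an exact-match filter over the subset dictionary (names are illustrative):

def get_class_images(subsets, class_name):
    # subsets: dict mapping subset names (e.g. "dog(snow)", "hotdog") to
    # image-id lists, as in full-candidate-subsets.pkl
    images = []
    for subset_name, image_ids in subsets.items():
        base = subset_name.split("(")[0]
        if base == class_name:  # exact match: keeps "dog(snow)", drops "hotdog"
            images.extend(image_ids)
    return images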

Code release

Excellent work! May I ask when you will release the code?

computing W_clamp_min

Hello,
I'm a PhD student working on XAI. I came across your paper and decided to test the provided implementation, for which I thank you. The code is very clear and easy to understand!

While studying the code, I noticed the following: in cce_utils.py, line 68, shouldn't W_clamp_min be computed as

(W_clamp_min / (min_margins * concept_norms)).T

instead of

(W_clamp_min / (max_margins * concept_norms)).T ?

I'm saying this based on equation (6) in section 3.2 of your paper. Please correct me if I'm wrong or if I misunderstood something.

Thank you in advance!
