Energy-Based Concept Bottleneck Models: Unifying Prediction, Concept Intervention, and Probabilistic Interpretations (ECBMs)

This repo is the official implementation of our ICLR 2024 paper:

Energy-Based Concept Bottleneck Models: Unifying Prediction, Concept Intervention, and Probabilistic Interpretations

Xinyue Xu, Yi Qin, Lu Mi, Hao Wang, Xiaomeng Li

Twelfth International Conference on Learning Representations (ICLR), 2024.

[Paper] [OpenReview] [PPT]

Overview of our ECBM

Top: During training, ECBM learns positive concept embeddings (in black), negative concept embeddings (in white), class embeddings (in black), and the three energy networks by minimizing the three energy functions, using the total loss function. The concept and class labels are treated as constants.

Bottom: During inference, we (1) freeze all concept and class embeddings as well as all networks, and (2) update the predicted concept probabilities and class probabilities by minimizing the three energy functions using the total loss function.
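The inference step described above (frozen parameters, probabilities updated by energy minimization) can be illustrated with a toy sketch. This is not the code in GradientInference.py: the quadratic `toy_energy`, the `preferred` values standing in for frozen networks, the learning rate, and the step count are all assumptions chosen to keep the example self-contained.

```python
import math


def sigmoid(z):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def toy_energy(probs, preferred):
    """Hypothetical stand-in for a frozen energy function.

    The energy is low when the probabilities match `preferred`, which plays
    the role of whatever the frozen embeddings and networks favor.
    """
    return sum((p - t) ** 2 for p, t in zip(probs, preferred))


def infer_probs(preferred, steps=500, lr=1.0):
    """Minimize the toy energy over probabilities by gradient descent.

    Only the logits are updated; `preferred` (the frozen part) never changes.
    """
    logits = [0.0] * len(preferred)  # start every probability at 0.5
    for _ in range(steps):
        probs = [sigmoid(z) for z in logits]
        # Chain rule: dE/dz = dE/dp * dp/dz = 2 (p - t) * p (1 - p)
        grads = [2.0 * (p - t) * p * (1.0 - p) for p, t in zip(probs, preferred)]
        logits = [z - lr * g for z, g in zip(logits, grads)]
    return [sigmoid(z) for z in logits]


if __name__ == "__main__":
    frozen = [0.9, 0.1, 0.7]
    result = infer_probs(frozen)
    print([round(p, 3) for p in result], toy_energy(result, frozen))
```

Parameterizing the probabilities through logits keeps them inside (0, 1) without an explicit projection step; the actual implementation may handle this differently.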

Installation

Prerequisites

We ran all experiments on an NVIDIA RTX 3090 GPU.

pip install -r requirements.txt

Dataset Preparation

Please specify the dataset folder path in data_util.py.

Configuration

Configurations are stored in the {dataset}/{dataset_inference}.json file.

  • To select a dataset, set dataset='TARGET DATASET'.
  • To use pretrained weights, set pretrained = true.
  • emb_size: the feature size after the feature encoder.
  • hid_size: projected feature size.
  • cpt_size: the number of concepts.
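As a rough illustration of the fields listed above, the sketch below writes a configuration to JSON and loads it back in Python. The field names come from the list; every value, the file name `example_inference.json`, and the helper names `load_config` and `roundtrip` are placeholder assumptions, not settings or functions from this repo.

```python
import json
import os
import tempfile

# Field names follow the list above; all values are illustrative placeholders.
example_config = {
    "dataset": "cub",    # one of the datasets accepted by --dataset
    "pretrained": True,  # serialized as `true` in JSON
    "emb_size": 2048,    # hypothetical feature size after the encoder
    "hid_size": 256,     # hypothetical projected feature size
    "cpt_size": 112,     # hypothetical number of concepts
}

REQUIRED_KEYS = {"dataset", "pretrained", "emb_size", "hid_size", "cpt_size"}


def load_config(path):
    """Load a JSON config and check that the documented fields are present."""
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    missing = REQUIRED_KEYS - config.keys()
    if missing:
        raise KeyError(f"missing config fields: {sorted(missing)}")
    return config


def roundtrip(config):
    """Write the config to a temporary JSON file and load it back."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "example_inference.json")  # hypothetical name
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
        return load_config(path)


if __name__ == "__main__":
    print(roundtrip(example_config))
```

Checking for required keys up front surfaces a misspelled field before a long run starts; the real config files may contain additional fields not listed here.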

Run Experiments

1. Training

To train our ECBM, run:

python main.py --dataset [cub/awa2/celeba]

2. Inference

To run gradient inference, specify the trained weights in the exp folder (change "trained_weight" to the last ckpt):

python GradientInference.py --dataset [cub/awa2/celeba]

3. Interventions

Individual Intervention

python GradientInference.py --dataset [cub/awa2/celeba] --intervene_type individual --missingratio [0.1, 0.9]

OR

./run_intervene_missing.sh
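To sweep several ratios, the command above can be generated programmatically. The sketch below only builds the command lines and uses the flags shown above; the step of 0.1 between ratios and the helper name `build_intervention_commands` are assumptions, and it does not reproduce the contents of run_intervene_missing.sh.

```python
import shlex

DATASETS = ("cub", "awa2", "celeba")


def build_intervention_commands(dataset, intervene_type="individual",
                                ratios=tuple(i / 10 for i in range(1, 10))):
    """Build one GradientInference.py command per missing ratio.

    The default ratios 0.1, 0.2, ..., 0.9 assume a step of 0.1 inside the
    documented [0.1, 0.9] range.
    """
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if intervene_type not in ("individual", "group"):
        raise ValueError(f"unknown intervention type: {intervene_type}")
    if intervene_type == "group" and dataset != "cub":
        # Only CUB has grouped concepts.
        raise ValueError("group intervention is only available for cub")
    commands = []
    for ratio in ratios:
        commands.append([
            "python", "GradientInference.py",
            "--dataset", dataset,
            "--intervene_type", intervene_type,
            "--missingratio", f"{ratio:.1f}",
        ])
    return commands


if __name__ == "__main__":
    for cmd in build_intervention_commands("cub"):
        print(shlex.join(cmd))
```

Each command is a list of arguments, so it can be passed to `subprocess.run(cmd, check=True)` from the repo root to launch the runs one after another.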

Group Intervention

Group intervention is available only for the CUB dataset; CelebA and AWA2 do not have grouped concepts.

python GradientInference.py --dataset cub --intervene_type group --missingratio [0.1, 0.9]

4. Interpretations

Proposition 3.2: Use CalcImportanceScore.py to generate c_gt/c_pred/y_gt/y_pred.npy.

Proposition 3.3/3.4/3.5: Plot heatmaps with plot_correlation.ipynb and plot_joint.ipynb.

Results

Prediction

Accuracy on Different Datasets. We report the mean and standard deviation from five runs with different random seeds. For ProbCBM (marked with “*”), we report the best results from the ProbCBM paper (Kim et al., 2023) for CUB and AWA2 datasets.

Concept Intervention

Performance with different ratios of intervened concepts on three datasets (with error bars). The intervention ratio denotes the proportion of provided correct concepts. We use CEM with RandInt. CelebA and AWA2 do not have grouped concepts; thus we adopt individual intervention.
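The intervention ratio can be made concrete with a small sketch of individual intervention: a fraction of the predicted concepts is overwritten with the provided correct values. The function name `intervene` and the uniform random choice of concepts are assumptions for illustration; the repo may select concepts differently, and how `--missingratio` maps onto this ratio is not specified here.

```python
import random


def intervene(pred_concepts, true_concepts, ratio, seed=0):
    """Replace a `ratio` fraction of predicted concepts with ground truth.

    Returns the intervened concept list and the sorted indices that were
    replaced. Concepts are picked uniformly at random (an assumption).
    """
    if len(pred_concepts) != len(true_concepts):
        raise ValueError("prediction and ground truth must have equal length")
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must lie in [0, 1]")
    n = len(pred_concepts)
    k = round(ratio * n)  # number of concepts provided as correct
    rng = random.Random(seed)
    chosen = set(rng.sample(range(n), k))
    intervened = [
        float(true_concepts[i]) if i in chosen else pred_concepts[i]
        for i in range(n)
    ]
    return intervened, sorted(chosen)


if __name__ == "__main__":
    predicted = [0.8, 0.4, 0.3, 0.6]  # made-up concept probabilities
    correct = [1, 1, 0, 0]            # made-up ground-truth concepts
    print(intervene(predicted, correct, ratio=0.5))
```

With a ratio of 0 the predictions are returned unchanged, and with a ratio of 1 every concept is replaced by its ground-truth value.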

Conditional Interpretations

Marginal concept importance for top 3 concepts of 4 different classes computed using Proposition 3.2. ECBM's estimation (Ours) is very close to the ground truth (Oracle).

Reference

@inproceedings{ECBM,
      title={Energy-Based Concept Bottleneck Models: Unifying Prediction, Concept Intervention, and Probabilistic Interpretations}, 
      author={Xu, Xinyue and Qin, Yi and Mi, Lu and Wang, Hao and Li, Xiaomeng},
      booktitle={International Conference on Learning Representations},
      year={2024}
}
