Causal Bert -- in Pytorch!

PyTorch implementation of "Adapting Text Embeddings for Causal Inference" by Victor Veitch, Dhanya Sridhar, and David M. Blei.

Quickstart

pip install -r requirements.txt
python CausalBert.py

This trains the model on test data and computes an average treatment effect (ATE).

Description

The system expects input data in which each row consists of:

  • Freeform text
  • A categorical variable (numerically coded) representing a confound
  • A binary treatment variable
  • A binary outcome variable
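Before training, it can help to confirm that the four columns have the shape described above. The helper below is a hypothetical sketch (`check_rows` is not part of this repository); it assumes the columns are passed as equal-length sequences, with the confound coded as a non-negative integer and the treatment and outcome coded as 0/1.

```python
from typing import Sequence


def check_rows(texts: Sequence[str], confounds: Sequence[int],
               treatments: Sequence[int], outcomes: Sequence[int]) -> None:
    """Hypothetical input check: raises ValueError on the first problem found."""
    n = len(texts)
    if not (len(confounds) == len(treatments) == len(outcomes) == n):
        raise ValueError("all four columns must have the same length")
    for i, (w, c, t, y) in enumerate(zip(texts, confounds, treatments, outcomes)):
        if not isinstance(w, str):
            raise ValueError(f"row {i}: text must be a string")
        # The confound is categorical but numerically coded (0, 1, 2, ...).
        if int(c) != c or c < 0:
            raise ValueError(f"row {i}: confound must be a non-negative integer code")
        if t not in (0, 1):
            raise ValueError(f"row {i}: treatment must be binary (0 or 1)")
        if y not in (0, 1):
            raise ValueError(f"row {i}: outcome must be binary (0 or 1)")
```

For example, `check_rows(["a good film", "a bad film"], [0, 2], [1, 0], [0, 1])` returns without error, while a treatment value of 2 raises `ValueError`.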

The system feeds the text to BERT and uses the BERT embeddings plus the confound to predict:

  1. P(T | C, text)
  2. P(Y | T = 1, C, text)
  3. P(Y | T = 0, C, text)
  4. The original masked language modeling objective of BERT.
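The `g_weight`, `Q_weight`, and `mlm_weight` arguments of `CausalBertWrapper` suggest that these four objectives are combined into one weighted training loss. The sketch below illustrates that idea on plain probabilities. It assumes binary cross-entropy for heads 1 to 3, that only the outcome head matching each row's observed treatment is supervised, and that the MLM loss arrives precomputed; `bce` and `combined_loss` are hypothetical names, not the repository's code.

```python
import math
from typing import Sequence


def bce(p: float, label: int, eps: float = 1e-12) -> float:
    """Binary cross-entropy for one predicted probability p = P(label = 1)."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))


def combined_loss(g: Sequence[float], q1: Sequence[float], q0: Sequence[float],
                  t: Sequence[int], y: Sequence[int], mlm_loss: float,
                  g_weight: float = 1.0, Q_weight: float = 0.1,
                  mlm_weight: float = 1.0) -> float:
    """Hypothetical weighted objective, averaged over a batch.

    g[i]  ~ P(T=1 | C, text)
    q1[i] ~ P(Y=1 | T=1, C, text)
    q0[i] ~ P(Y=1 | T=0, C, text)
    """
    n = len(t)
    g_loss = sum(bce(g[i], t[i]) for i in range(n)) / n
    # Only the outcome head for the treatment actually received gets a label.
    q_loss = sum(bce(q1[i] if t[i] == 1 else q0[i], y[i]) for i in range(n)) / n
    return g_weight * g_loss + Q_weight * q_loss + mlm_weight * mlm_loss
```

Raising `mlm_weight` relative to the other two weights keeps the embeddings closer to ordinary BERT, while raising `g_weight` and `Q_weight` pushes them toward information that predicts treatment and outcome.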

Once trained, the resulting BERT embeddings will be sufficient for some causal inferences, such as estimating the ATE.
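For reference, one standard way to turn heads 2 and 3 into an ATE estimate is the plug-in (outcome-model) estimator below. It assumes a binary outcome and that the confound and text together are enough to adjust for confounding (no unblocked backdoor paths, plus overlap); whether `CausalBertWrapper` computes exactly this quantity is not stated here.

```latex
\begin{aligned}
\mathrm{ATE} &= \mathbb{E}[Y \mid \mathrm{do}(T=1)] - \mathbb{E}[Y \mid \mathrm{do}(T=0)] \\
             &= \mathbb{E}_{C,\,\text{text}}\big[ P(Y=1 \mid T=1, C, \text{text}) - P(Y=1 \mid T=0, C, \text{text}) \big] \\
\widehat{\mathrm{ATE}} &= \frac{1}{n} \sum_{i=1}^{n} \big( \hat{Q}_1(c_i, w_i) - \hat{Q}_0(c_i, w_i) \big),
\qquad \hat{Q}_t(c, w) = \hat{P}(Y=1 \mid T=t, C=c, \text{text}=w)
\end{aligned}
```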

Example

import pandas as pd
from CausalBert import CausalBertWrapper                    # CausalBert.py from this repo

df = pd.read_csv('testdata.csv')                           # columns: text, C, T, Y
cb = CausalBertWrapper(batch_size=2,                       # init a model wrapper
    g_weight=0.1, Q_weight=0.1, mlm_weight=1)
cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1)  # train the model
print(cb.ATE(df['C'], df['text'], platt_scaling=True))     # use the model to get an average treatment effect

Usage

Initialize the model wrapper (handles training and inference):

cb = CausalBertWrapper(
  batch_size=2,   # batch size for training
  g_weight=1.0,   # loss weight for P(T | C, text) prediction head
  Q_weight=0.1,   # loss weight for P(Y | T, C, text) prediction heads
  mlm_weight=1)   # loss weight for original MLM objective

Then train the model:

cb.train(
  df['text'],    # list of texts
  df['C'],       # list of confounds
  df['T'],       # list of treatments
  df['Y'],       # list of outcomes
  epochs=1)      # training epochs

Perform inference:

outputs = cb.inference(
  df['text'],   # list of texts
  df['C'])      # list of confounds
# each element of outputs: ((P(Y=1|T=1), P(Y=0|T=1)), (P(Y=1|T=0), P(Y=0|T=0)))
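If each element of the inference output follows the ((P(Y=1|T=1), P(Y=0|T=1)), (P(Y=1|T=0), P(Y=0|T=0))) layout described above, per-example effect estimates and their average can be read off directly. The helpers below are hypothetical and assume that layout; check the actual return value of `cb.inference` before relying on them.

```python
from typing import List, Sequence, Tuple

# Assumed layout of one inference output element (taken from the description above):
# ((P(Y=1|T=1), P(Y=0|T=1)), (P(Y=1|T=0), P(Y=0|T=0)))
Pair = Tuple[float, float]
Element = Tuple[Pair, Pair]


def individual_effects(outputs: Sequence[Element]) -> List[float]:
    """Per-example difference P(Y=1|T=1) - P(Y=1|T=0)."""
    return [treated[0] - control[0] for treated, control in outputs]


def average_effect(outputs: Sequence[Element]) -> float:
    """Plug-in average treatment effect: the mean of the per-example differences."""
    effects = individual_effects(outputs)
    return sum(effects) / len(effects)
```

For two examples `[((0.8, 0.2), (0.5, 0.5)), ((0.6, 0.4), (0.4, 0.6))]`, the per-example effects are about 0.3 and 0.2, and the average is about 0.25.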

Or estimate an average treatment effect:

ATE = cb.ATE(
  df['C'],      # list of confounds
  df['text'],   # list of texts
  platt_scaling=False)    # https://en.wikipedia.org/wiki/Platt_scaling
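Platt scaling recalibrates a model's scores by fitting a one-dimensional logistic regression, P(y=1 | s) = sigmoid(a*s + b), against observed labels. The sketch below fits `a` and `b` with plain gradient descent on the mean log loss. It illustrates the general technique only (it omits Platt's smoothed targets) and is not necessarily how `CausalBertWrapper` implements its `platt_scaling` option; `fit_platt` and `calibrate` are hypothetical names.

```python
import math
from typing import Callable, Sequence, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_platt(scores: Sequence[float], labels: Sequence[int],
              lr: float = 0.1, steps: int = 2000) -> Tuple[float, float]:
    """Fit P(y=1 | s) = sigmoid(a*s + b) by gradient descent on mean log loss."""
    a, b = 1.0, 0.0
    n = len(scores)
    for _ in range(steps):
        grad_a = grad_b = 0.0
        for s, y in zip(scores, labels):
            err = sigmoid(a * s + b) - y  # derivative of log loss w.r.t. the logit
            grad_a += err * s
            grad_b += err
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return a, b


def calibrate(a: float, b: float) -> Callable[[float], float]:
    """Return a function mapping a raw score to a calibrated probability."""
    return lambda s: sigmoid(a * s + b)
```

Because the loss is convex in `(a, b)`, a small fixed learning rate is enough here; at the optimum the mean calibrated probability matches the fraction of positive labels.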

Contributors

reidpryzant, rpryzant
