
cfml_tools: Counterfactual Machine Learning Tools

For a long time, ML practitioners and statisticians repeated the same mantra: correlation is not causation. This warning prevented (at least some) people from drawing wrong conclusions from models. However, it also created the misconception that models cannot be causal. With a few tweaks drawn from the causal inference literature (DeepIV, Generalized Random Forests, Causal Trees) and the Reinforcement Learning literature (bandits, Thompson Sampling), we can actually make machine learning methods aware of causality!

cfml_tools is my collection of causal inference algorithms built on top of accessible, simple, out-of-the-box ML methods, aimed at being explainable and useful in the business context.

Installation

Open your terminal and run:

git clone https://github.com/gdmarmerola/cfml_tools.git
cd cfml_tools
python setup.py install

Basic Usage

Use the example dataset to test the package:

from cfml_tools.utils import make_confounded_data

# compliments to nubank fklearn library
df_rnd, df_obs, df_cf = make_confounded_data(500000)

# organizing data into X, W and y
X = df_obs[['sex','age','severity']]
W = df_obs['medication'].astype(int)
y = df_obs['recovery']
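The observational dataset is confounded by design: sicker patients are more likely to receive medication. Below is a minimal, self-contained sketch of that same kind of confounding with toy numbers (not the fklearn generator), showing why a naive group comparison goes wrong:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n = 10_000

# toy confounded dataset: sicker patients are more likely to get medication,
# so a naive comparison of treated vs. untreated recovery is biased
severity = rng.uniform(0, 1, n)
medication = rng.binomial(1, 0.2 + 0.6 * severity)  # treatment depends on severity
recovery = 20 + 30 * severity - 5 * medication + rng.normal(0, 2, n)  # true effect: -5

df = pd.DataFrame({'severity': severity, 'medication': medication, 'recovery': recovery})

# naive difference-in-means: biased toward zero (or even positive),
# because the treated group is sicker on average
naive = (df.loc[df.medication == 1, 'recovery'].mean()
         - df.loc[df.medication == 0, 'recovery'].mean())
print(round(naive, 1))
```

Even though the medication shortens recovery by 5 days for every patient, the naive estimate lands far from -5 because severity drives both treatment and outcome.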

Our end goal is to obtain counterfactual predictions. These come from the .predict() method, which returns a dataframe with the expected outcome under each treatment, whenever it can be computed. The package uses a scikit-learn-like API and is fairly easy to use.

# importing cfml-tools
from cfml_tools.tree import DecisionTreeCounterfactual

# instance of DecisionTreeCounterfactual
dtcf = DecisionTreeCounterfactual(save_explanatory=True)

# fitting model to our data
dtcf.fit(X, W, y)

# predicting counterfactuals
counterfactuals = dtcf.predict(X)
counterfactuals.iloc[5:10]

DecisionTreeCounterfactual, in particular, builds a decision tree to solve a regression or classification problem from explanatory variables X to target y, and then compares outcomes for every treatment W at each leaf node to build counterfactual predictions. It yields great results on fklearn's causal inference problem out-of-the-box.
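The mechanism described above can be sketched with plain scikit-learn. This is an illustrative reimplementation of the idea, not the package's actual code: fit a tree from X to y, send every sample to its leaf, and compare mean outcomes per treatment inside each leaf:

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
n = 5_000

# synthetic data: outcome depends on x and on a binary treatment W (true effect: -5)
X = pd.DataFrame({'x': rng.uniform(0, 1, n)})
W = rng.binomial(1, 0.5, n)
y = 10 + 20 * X['x'] - 5 * W + rng.normal(0, 1, n)

# sketch of the DecisionTreeCounterfactual idea: fit a tree from X to y,
# then compare outcomes for each treatment inside every leaf
tree = DecisionTreeRegressor(min_samples_leaf=200).fit(X, y)
leaves = tree.apply(X)

leaf_outcomes = (
    pd.DataFrame({'leaf': leaves, 'W': W, 'y': y})
    .groupby(['leaf', 'W'])['y'].mean()
    .unstack('W')  # one column of expected outcome per treatment
)

# counterfactual prediction for new samples: look up their leaf's per-treatment means
counterfactuals = leaf_outcomes.loc[tree.apply(X.iloc[:5])]
print(counterfactuals.round(1))
```

The per-leaf difference between the treatment columns recovers the true effect, because within a leaf the samples are similar in X and the treatment is as good as random.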

It is also very fast, even on the 500k-record example dataset.

Additional features

Run cross validation to get X to y prediction results

Check how well the underlying model predicts the outcome, regardless of treatment (this score should be as high as possible):

cv_scores = dtcf.get_cross_val_scores(X, y)
print(cv_scores)
[0.55007156 0.54505553 0.54595812 0.55107778 0.5513648 ]
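A comparable check can be run with plain scikit-learn, as a hedged sketch of what cross-validating the underlying X-to-y model looks like (synthetic data; the package's internal model and scoring may differ):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(1)
n = 5_000
X = pd.DataFrame({'x': rng.uniform(0, 1, n)})
y = 10 + 20 * X['x'] + rng.normal(0, 3, n)

# 5-fold cross-validated R^2 of the underlying X -> y regression
scores = cross_val_score(DecisionTreeRegressor(min_samples_leaf=100), X, y, cv=5)
print(scores.round(2))
```

If these scores are low, the tree is not capturing the relationship between explanatory variables and outcome, and the leaf-level counterfactual comparisons will be unreliable.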

Explain counterfactual predictions using leaves

The .explain() method provides explanations by returning the samples in the leaf node that was used to perform the counterfactual prediction.

# sample we wish to explain
test_sample = X.iloc[[5000]]

# running explanation
comparables_table = dtcf.explain(test_sample)
comparables_table.groupby('W').head(2)
 index  sex  age  severity  W   y
  5000    1   36      0.23  0  52
  5719    1   35      0.25  0  38
 23189    1   37      0.22  1  13
 35839    1   35      0.25  1  11

This way we can compare samples and check whether we can trust the estimated effect. In this particular case the prediction seems reliable, as the treated and untreated groups contain very similar individuals:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 4, figsize=(16, 5), dpi=150)
comparables_table.boxplot('age', 'W', ax=ax[0])
comparables_table.boxplot('sex', 'W', ax=ax[1])
comparables_table.boxplot('severity', 'W', ax=ax[2])
comparables_table.boxplot('y', 'W', ax=ax[3])
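The "comparables" idea can be sketched with plain scikit-learn: the samples used to explain a prediction are simply the training samples that share a leaf with it. This is an illustrative toy version, not the package's implementation:

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(4)
n = 2_000
X = pd.DataFrame({'age': rng.integers(20, 60, n).astype(float)})
W = rng.binomial(1, 0.5, n)
y = 50 - 0.5 * X['age'] - 5 * W + rng.normal(0, 2, n)

# sketch of .explain(): the comparables are the training samples
# that fall in the same leaf as the sample being explained
tree = DecisionTreeRegressor(min_samples_leaf=100).fit(X, y)
train = X.assign(W=W, y=y, leaf=tree.apply(X))

test_sample = X.iloc[[0]]
comparables = train[train['leaf'] == tree.apply(test_sample)[0]]
print(comparables.groupby('W')['y'].mean().round(1))
```

Inspecting the comparables (for example with the boxplots above) tells you whether treated and untreated neighbors really are similar enough to support the counterfactual estimate.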

[Experimental] Further criticize the model using leaf diagnostics

We can inspect the model further by using the .run_leaf_diagnostics() method.

# running leaf diagnostics
leaf_diagnostics_df = dtcf.run_leaf_diagnostics()
leaf_diagnostics_df.head()

The method provides a diagnostic on leaves valid for counterfactual inference, showing some interesting quantities:

  • average outcomes across treatments (avg_outcome)
  • explanatory variable distribution across treatments (percentile_* variables)
  • a confounding score for each variable, meaning how much we can predict the treatment from explanatory variables inside leaf nodes using a linear model (confounding_score)

In particular, confounding_score measures whether treatments are assigned non-randomly given the explanatory variables, which is a major source of bias in causal inference models. As this score gets bigger, we tend to miss the real effect more.
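One way to build a score in this spirit (the package's exact definition may differ) is to measure how well a linear model separates treated from untreated samples using only the explanatory variables, e.g. via AUC:

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(2)
n = 5_000
X = pd.DataFrame({'severity': rng.uniform(0, 1, n)})

# scenario A: randomized treatment; scenario B: treatment depends on severity
W_random = rng.binomial(1, 0.5, n)
W_confounded = rng.binomial(1, 0.1 + 0.8 * X['severity'])

def confounding_score(X, W):
    """AUC of a linear model predicting treatment from X (0.5 = unconfounded)."""
    probs = LogisticRegression().fit(X, W).predict_proba(X)[:, 1]
    return roc_auc_score(W, probs)

print(round(confounding_score(X, W_random), 2))      # close to 0.5
print(round(confounding_score(X, W_confounded), 2))  # well above 0.5
```

A score near 0.5 means treatment assignment looks random given X; the further above 0.5, the more confounded the (sub)sample is.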

[Experimental] Better visualize and understand your problem with forest embeddings

Besides DecisionTreeCounterfactual, we provide ForestEmbeddingsCounterfactual, which is still experimental. A cool thing to do with this model is to plot the forest embeddings of your problem. The method uses leaf co-occurrence as a similarity metric and UMAP for dimensionality reduction.

# `fecf` is a fitted ForestEmbeddingsCounterfactual instance
# getting embedding from data
reduced_embed = fecf.get_umap_embedding(X)
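The leaf co-occurrence similarity that underlies the embedding can be sketched with a plain scikit-learn random forest (an illustrative reimplementation; the package's internals may differ, and the embedding step would then run UMAP on top of this matrix):

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(3)
n = 200
X = pd.DataFrame({'x1': rng.uniform(0, 1, n), 'x2': rng.uniform(0, 1, n)})
y = 10 * X['x1'] + rng.normal(0, 1, n)

# leaf co-occurrence similarity: the fraction of trees in which
# two samples end up in the same leaf
forest = RandomForestRegressor(n_estimators=100).fit(X, y)
leaves = forest.apply(X)  # shape (n_samples, n_trees)

similarity = (leaves[:, None, :] == leaves[None, :, :]).mean(axis=2)
print(similarity.shape)
```

Samples that the forest treats as interchangeable get similarity near 1, which is why clusters in the resulting embedding group samples the model considers comparable.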

This allows for some cool visuals for diagnosing your problem, such as the distribution of features across the embedding, or how treatments and outcomes are distributed, to check "where" inference is valid.

Additional resources

You can find several additional resources in the repository.

I hope you'll use cfml_tools for your causal inference problems soon! All feedback is appreciated :)


