causal-bert-pytorch's People

Contributors

reidpryzant, rpryzant

causal-bert-pytorch's Issues

Possible reversed subtraction in the `.ATE()` method

Hi @rpryzant

Thank you for sharing your implementation of CausalBERT!

I tried it recently on a number of synthetic datasets, and the results from the `.ATE()` method surprised me.

The estimated effect was systematically reversed relative to what I expected.

I went through the code, and the subtraction in the return statement of the `.ATE()` method appears to be reversed:

    def ATE(self, C, W, Y=None, platt_scaling=False):
        Q_probs, _, Ys = self.inference(W, C, outcome=Y)
        if platt_scaling and Y is not None:
            Q0 = platt_scale(Ys, Q_probs[:, 0])[:, 0]
            Q1 = platt_scale(Ys, Q_probs[:, 1])[:, 1]
        else:
            Q0 = Q_probs[:, 0]
            Q1 = Q_probs[:, 1]

        return np.mean(Q0 - Q1)

According to the original paper, Q1 represents the outcome under treatment, (Y|do(T=1)), while Q0 represents the outcome under no treatment, (Y|do(T=0)).

We usually define the ATE as E[P(Y|do(T=1)) - P(Y|do(T=0))], that is, treated minus untreated.

The `.ATE()` method returns np.mean(Q0 - Q1), which appears to have the subtraction reversed.

What are your thoughts on this?
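
To make the sign issue concrete, here is a minimal sketch with made-up probabilities. The lists `q0` and `q1` and both helper functions are hypothetical and not part of the repository; the sketch only compares the usual treated-minus-control convention with the order used in the quoted return statement.

```python
from statistics import mean

# Hypothetical per-example outcome probabilities (not from the repository):
# q0[i] stands in for P(Y=1 | do(T=0)), q1[i] for P(Y=1 | do(T=1)).
q0 = [0.20, 0.30, 0.25, 0.35]
q1 = [0.60, 0.70, 0.65, 0.75]


def ate_treated_minus_control(q0, q1):
    """ATE under the usual convention: mean of Q1 - Q0."""
    return mean(b - a for a, b in zip(q0, q1))


def ate_as_quoted(q0, q1):
    """Mirrors the quoted return statement: mean of Q0 - Q1."""
    return mean(a - b for a, b in zip(q0, q1))


usual = ate_treated_minus_control(q0, q1)
quoted = ate_as_quoted(q0, q1)
print(usual, quoted)
```

In this toy data the treated outcome is always more likely than the control outcome, so the usual convention gives a positive effect, while the quoted order gives the same magnitude with the opposite sign.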

Why choose the last column of Q0, Q1, and g?

Hi Reid, thanks for your implementation. I have two questions about the code.

I was wondering why the last column of Q0, Q1, and g is selected on lines 165-167 of the code:

    Q0 = sm(Q_logits_T0)[:, 1]
    Q1 = sm(Q_logits_T1)[:, 1]
    g = sm(g)[:, 1]

Is that equivalent to computing P(Y=1 | T=0, C, text), P(Y=1 | T=1, C, text), and P(T=1 | C, text)?

According to the definitions in the paper, however, shouldn't we compute P(Y | T=0, C, text), P(Y | T=1, C, text), and P(T | C, text)? Please correct me if I'm wrong.
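
One way to look at the column question: for a binary label, the two softmax columns sum to one, so column 1 alone determines the whole distribution. The sketch below uses a hand-written `softmax` helper and made-up logits (both hypothetical), and assumes that index 1 corresponds to Y=1, which the quoted snippet does not state explicitly.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical two-class logits for one example: [logit for Y=0, logit for Y=1].
q_logits_t1 = [0.3, 1.2]
probs = softmax(q_logits_t1)

p_y1 = probs[1]  # analogue of sm(Q_logits_T1)[:, 1]
p_y0 = probs[0]  # the other column, recoverable as 1 - p_y1
print(p_y0, p_y1)
```

Under that assumption, keeping only column 1 loses no information in the binary case, since P(Y=0 | ...) is simply 1 minus the kept value.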

My second question: why are the Y labels set to -100 for the T=0 examples when computing the cross-entropy for Q1, and for the T=1 examples when computing it for Q0?

    T0_indices = (T == 0).nonzero().squeeze()
    Y_T1_labels = Y.clone().scatter(0, T0_indices, -100)
    T1_indices = (T == 1).nonzero().squeeze()
    Y_T0_labels = Y.clone().scatter(0, T1_indices, -100)
    Q_loss_T1 = CrossEntropyLoss()(Q_logits_T1.view(-1, self.num_labels), Y_T1_labels)
    Q_loss_T0 = CrossEntropyLoss()(Q_logits_T0.view(-1, self.num_labels), Y_T0_labels)

I'm new to causal inference and would really appreciate your help!
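
A possible reading of the -100 labels: `torch.nn.CrossEntropyLoss` has a default `ignore_index` of -100, so rows carrying that label contribute nothing to the loss. The sketch below re-implements that masking in plain Python to show the effect; `masked_cross_entropy` and the toy batch are hypothetical stand-ins, not the repository's code.

```python
import math

IGNORE_INDEX = -100  # same value as the default ignore_index of torch.nn.CrossEntropyLoss


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over rows whose label is not ignore_index."""
    losses = []
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked rows are skipped entirely
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[label])
    return sum(losses) / len(losses) if losses else float("nan")


# Hypothetical batch: treatments, outcomes, and logits from the T=1 head.
T = [1, 0, 1, 0]
Y = [1, 0, 0, 1]
q_logits_t1 = [[0.2, 1.5], [3.0, -3.0], [0.7, 0.1], [-4.0, 4.0]]

# Keep Y where T == 1 and mask it where T == 0 (analogue of Y_T1_labels).
y_t1_labels = [y if t == 1 else IGNORE_INDEX for t, y in zip(T, Y)]
loss_t1 = masked_cross_entropy(q_logits_t1, y_t1_labels)

# Changing the logits of the masked (T == 0) rows leaves the loss unchanged.
altered = [row if t == 1 else [0.0, 0.0] for row, t in zip(q_logits_t1, T)]
loss_t1_altered = masked_cross_entropy(altered, y_t1_labels)
print(loss_t1, loss_t1_altered)
```

If this reading is right, each outcome head is trained only on examples that actually received its treatment value, which would match the fact that only one potential outcome is observed per example.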

Specify versions in pip requirements?

Hi Reid,

Thanks for making this package! Could you pin the version numbers in requirements.txt?

Without pinned versions, running `python CausalBert.py` fails with errors like:

    Traceback (most recent call last):
      File "CausalBert.py", line 19, in <module>
        from transformers.modeling_bert import BertPreTrainingHeads
    ModuleNotFoundError: No module named 'transformers.modeling_bert'

Thanks,
Katie
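
One possible workaround, assuming the error comes from the module reorganization in transformers 4.x (where, as far as I know, `transformers.modeling_bert` moved to `transformers.models.bert.modeling_bert`), is to try both import locations. The helper `import_first_available` and the wrapper `load_bert_pretraining_heads` below are hypothetical names introduced for this sketch only.

```python
import importlib


def import_first_available(module_paths, attr):
    """Return `attr` from the first module in `module_paths` that imports and defines it."""
    errors = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{path}: {exc}")
    raise ImportError(f"could not import {attr!r}; tried: " + "; ".join(errors))


# Candidate locations: the assumed 4.x path first, then the path from the traceback.
BERT_MODULE_PATHS = [
    "transformers.models.bert.modeling_bert",
    "transformers.modeling_bert",
]


def load_bert_pretraining_heads():
    """Look up BertPreTrainingHeads under either module layout."""
    return import_first_available(BERT_MODULE_PATHS, "BertPreTrainingHeads")
```

Pinning `transformers` in requirements.txt to the release the code was written against would be the more direct fix; that version is not stated here, so the fallback import only sidesteps this one error.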
