representation-causal-public's Issues

File missing

The README says "To run the study that learns disentangled representation with VAE+IOSS, run the src/run_disentanglement_learn.sh", but there is no file named run_disentanglement_learn.sh in the 'src' folder.

TestPredNet Loss Calculation is NOT in the Loop

In several places where a testprednet is optimized, the forward pass and the loss computation sit outside the training loop, so the loss is computed only once and never reflects the updated weights:

    # below, currently only predict using the learned representation, consider including vaez too.
    baselinevaetestpred_trainX, baselinevaetestpred_trainy = baselinevaemlp(envs[1]['vaez'][::2])[0], envs[1]['labels'][::2]
    baselinevaetestpred_testX, baselinevaetestpred_testy = baselinevaemlp(envs[1]['vaez'][1::2])[0], envs[1]['labels'][1::2]
    baselinevaeprediction = baselinevaetestprednet(baselinevaetestpred_trainX)     # input x and predict based on x
    baselinevaeloss = loss_func(nn.Sigmoid()(baselinevaeprediction), baselinevaetestpred_trainy)     # must be (1. nn output, 2. target)


    for t in range(200):
        optimizer_baselinevaetestpred.zero_grad()   # clear gradients for next train
        baselinevaeloss.backward(retain_graph=True)         # backpropagation, compute gradients
        optimizer_baselinevaetestpred.step()        # apply gradients

It should be:

    # below, currently only predict using the learned representation, consider including vaez too.
    baselinevaetestpred_trainX, baselinevaetestpred_trainy = baselinevaemlp(envs[1]['vaez'][::2])[0], envs[1]['labels'][::2]
    baselinevaetestpred_testX, baselinevaetestpred_testy = baselinevaemlp(envs[1]['vaez'][1::2])[0], envs[1]['labels'][1::2]

    for t in range(200):
        baselinevaeprediction = baselinevaetestprednet(baselinevaetestpred_trainX)     # input x and predict based on x
        baselinevaeloss = loss_func(nn.Sigmoid()(baselinevaeprediction), baselinevaetestpred_trainy)     # must be (1. nn output, 2. target)
        optimizer_baselinevaetestpred.zero_grad()   # clear gradients for next train
        baselinevaeloss.backward(retain_graph=True)         # backpropagation, compute gradients
        optimizer_baselinevaetestpred.step()        # apply gradients
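To illustrate why the forward pass and loss belong inside the loop, here is a minimal, dependency-free sketch. It is an analogy only, not the repository's code: the one-parameter model `y = w * x`, the functions `grad`, `train_stale`, and `train_fresh`, and the learning rate are all hypothetical. `train_stale` mirrors the reported pattern (the quantity driving the update is computed once, before the loop), while `train_fresh` mirrors the proposed fix.

```python
def grad(w, xs, ys):
    """Gradient of the mean squared error of y ~ w * x with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def train_stale(w, xs, ys, lr=0.05, steps=200):
    """Compute the gradient once, before the loop, and reuse it every step."""
    g = grad(w, xs, ys)  # evaluated at the initial w only
    for _ in range(steps):
        w -= lr * g  # same update every step, regardless of progress
    return w


def train_fresh(w, xs, ys, lr=0.05, steps=200):
    """Recompute the gradient from the current w on every step."""
    for _ in range(steps):
        g = grad(w, xs, ys)  # reflects the latest update to w
        w -= lr * g
    return w


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # generated by y = 2 * x, so the optimum is w = 2

w_stale = train_stale(0.0, xs, ys)
w_fresh = train_fresh(0.0, xs, ys)
print("stale gradient:", w_stale)
print("fresh gradient:", w_fresh)
```

With the gradient recomputed each step, `w` settles at the optimum; with the stale gradient, every step applies the same update and `w` ends up far past it.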

Confusion between Supervised and Unsupervised Parts

In sec2-4-3-1-colored_mnist, the supervised and unsupervised parts are mixed together. For instance, the supervised method "Causal_Rep" is found in "colored_mnist_unsupervised_expm.py".

This is confusing. Please consider fixing it.

Rebased.

0xa8cECDAd830E35a20492a098DdB1269cE4d5E7d2
