keras-adda

This is an implementation of Adversarial Discriminative Domain Adaptation (ADDA, https://arxiv.org/abs/1702.05464) using Keras.

The class implementing ADDA is in adda.py.

usage: adda.py [-h] [-s SOURCE_WEIGHTS] [-e START_EPOCH]
               [-n DISCRIMINATOR_EPOCHS] [-f]
               [-a SOURCE_DISCRIMINATOR_WEIGHTS]
               [-b TARGET_DISCRIMINATOR_WEIGHTS] [-t EVAL_SOURCE_CLASSIFIER]
               [-d EVAL_TARGET_CLASSIFIER]

optional arguments:
  -h, --help            show this help message and exit
  -s SOURCE_WEIGHTS, --source_weights SOURCE_WEIGHTS
                        Path to weights file to load source model for training
                        classification/adaptation
  -e START_EPOCH, --start_epoch START_EPOCH
                        Epoch to begin training source model from
  -n DISCRIMINATOR_EPOCHS, --discriminator_epochs DISCRIMINATOR_EPOCHS
                        Max number of steps to train discriminator
  -f, --train_discriminator
                        Train discriminator model (if TRUE) vs Train source
                        classifier
  -a SOURCE_DISCRIMINATOR_WEIGHTS, --source_discriminator_weights SOURCE_DISCRIMINATOR_WEIGHTS
                        Path to weights file to load source discriminator
  -b TARGET_DISCRIMINATOR_WEIGHTS, --target_discriminator_weights TARGET_DISCRIMINATOR_WEIGHTS
                        Path to weights file to load target discriminator
  -t EVAL_SOURCE_CLASSIFIER, --eval_source_classifier EVAL_SOURCE_CLASSIFIER
                        Path to source classifier model to test/evaluate
  -d EVAL_TARGET_CLASSIFIER, --eval_target_classifier EVAL_TARGET_CLASSIFIER
                        Path to target discriminator model to test/evaluate
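The help text above can be approximated with a small argparse parser. This is only a sketch reconstructed from the usage message, not the code in adda.py: the function name `build_parser`, the `int` types, and the `None` defaults are assumptions, since the help output does not state them.

```python
import argparse


def build_parser():
    """Hypothetical reconstruction of the adda.py CLI from its help text."""
    parser = argparse.ArgumentParser(prog="adda.py")
    parser.add_argument("-s", "--source_weights", default=None,
                        help="Path to weights file to load source model")
    # Assumption: epochs are integers; the real defaults are not documented.
    parser.add_argument("-e", "--start_epoch", type=int, default=None,
                        help="Epoch to begin training source model from")
    parser.add_argument("-n", "--discriminator_epochs", type=int, default=None,
                        help="Max number of steps to train discriminator")
    # The usage line shows -f without a value, so it is modelled as a flag.
    parser.add_argument("-f", "--train_discriminator", action="store_true",
                        help="Train discriminator instead of source classifier")
    parser.add_argument("-a", "--source_discriminator_weights", default=None,
                        help="Path to weights file to load source discriminator")
    parser.add_argument("-b", "--target_discriminator_weights", default=None,
                        help="Path to weights file to load target discriminator")
    parser.add_argument("-t", "--eval_source_classifier", default=None,
                        help="Path to source classifier model to evaluate")
    parser.add_argument("-d", "--eval_target_classifier", default=None,
                        help="Path to target discriminator model to evaluate")
    return parser


if __name__ == "__main__":
    # Example: parse an explicit argument list (file name is a placeholder).
    example = build_parser().parse_args(["-f", "-s", "source.hdf5", "-n", "500"])
    print(example)
```

Parsing an explicit list, as in the example, makes it easy to check how a given command line maps onto the options without starting a training run.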

Run Source Classifier on MNIST:

To run the source encoder model on MNIST, use:

python adda.py [-e START_EPOCH]

Run Target Discriminator on MNIST and SVHN:

To train the discriminator on MNIST (source domain) and SVHN (target domain) to increase domain confusion, use:

python adda.py -f [-s SOURCE_WEIGHTS] [-n DISCRIMINATOR_EPOCHS] [-a SOURCE_DISCRIMINATOR_WEIGHTS] [-b TARGET_DISCRIMINATOR_WEIGHTS]
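To make "domain confusion" concrete, here is a minimal pure-Python sketch of the GAN-style adversarial objective described in the ADDA paper. It is an illustration, not the code in adda.py: the function names, and the convention that source features are labelled 1 and target features 0, are assumptions. The inputs are discriminator output probabilities, standing in for real model predictions.

```python
import math


def binary_cross_entropy(probs, label):
    """Mean binary cross-entropy of predicted probabilities against one label."""
    eps = 1e-7  # clamp to avoid log(0)
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1.0 - eps)
        total += -(label * math.log(p) + (1 - label) * math.log(1.0 - p))
    return total / len(probs)


def discriminator_loss(d_source, d_target):
    """Discriminator objective: call source features 1 and target features 0."""
    return binary_cross_entropy(d_source, 1) + binary_cross_entropy(d_target, 0)


def target_encoder_loss(d_target):
    """Target encoder objective with inverted labels: make target look like source."""
    return binary_cross_entropy(d_target, 1)


if __name__ == "__main__":
    # A confident discriminator: low discriminator loss, high encoder loss.
    print(discriminator_loss([0.9, 0.8], [0.1, 0.2]))
    print(target_encoder_loss([0.1, 0.2]))
    # A fully confused discriminator outputs 0.5 for every sample.
    print(target_encoder_loss([0.5, 0.5]))
```

Under this formulation the two losses pull in opposite directions: as the target encoder improves, the discriminator's outputs on target features drift toward those on source features, which is what increasing domain confusion means here.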

Evaluate Source Classifier on MNIST:

python adda.py -t SOURCE_CLASSIFIER_WEIGHTS

Evaluate Target Classifier on SVHN based on Domain Confusion:

python adda.py -t SOURCE_CLASSIFIER_WEIGHTS -d TARGET_DISCRIMINATOR_WEIGHTS
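The four commands in this README imply that the flags select one of four modes. The sketch below shows one way that selection could work. The function `select_mode`, the mode names, and the order of the checks are hypothetical and inferred only from the commands shown here, not taken from adda.py. A `SimpleNamespace` stands in for the parsed arguments.

```python
from types import SimpleNamespace


def select_mode(args):
    """Map parsed flags to a mode name (hypothetical dispatch logic)."""
    # -t together with -d: evaluate the adapted target classifier on SVHN.
    if args.eval_source_classifier and args.eval_target_classifier:
        return "evaluate_target"
    # -t alone: evaluate the source classifier on MNIST.
    if args.eval_source_classifier:
        return "evaluate_source"
    # -f: train the discriminator on MNIST and SVHN.
    if args.train_discriminator:
        return "train_discriminator"
    # No mode flag: train the source classifier on MNIST.
    return "train_source"


def make_args(**overrides):
    """Build an argument namespace with every mode flag unset by default."""
    base = dict(train_discriminator=False,
                eval_source_classifier=None,
                eval_target_classifier=None)
    base.update(overrides)
    return SimpleNamespace(**base)


if __name__ == "__main__":
    # File names below are placeholders, not files shipped with the project.
    print(select_mode(make_args()))
    print(select_mode(make_args(train_discriminator=True)))
    print(select_mode(make_args(eval_source_classifier="source.hdf5")))
    print(select_mode(make_args(eval_source_classifier="source.hdf5",
                                eval_target_classifier="target.hdf5")))
```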

Contributors

hanskrupakar
