
Conditional WassersteinGAN with Auxiliary Classifier

TensorFlow implementation of a conditional WassersteinGAN, based on the paper Wasserstein GAN.

Note: I will use the terms discriminator and critic interchangeably.

  1. Prerequisites
  2. Network Structure
  3. Thoughts
  4. Result
  5. References

Prerequisites

  • The code was tested on Linux.
  • TensorFlow v1.0. Older versions may require some manual renaming.
  • To train, call main() in WGAN_AC.py. To test, call test().

Network Structure

  • The network structures inside the generator and the discriminator are built similarly to DCGAN, which differs from the authors' PyTorch implementation.
  • Because the critic in a WassersteinGAN no longer performs classification, I have to use a separate classifier to enforce the condition. See the next section.
  • Generator --------> Discriminator ---> Wasserstein Distance
    |
    |----> Classifier ---> class1, class2, ..., classN
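The diagram above can be sketched as a combined generator objective: the generator is pushed to raise the critic's Wasserstein score while the separate classifier pulls the sample toward its condition. This is only an illustrative sketch; the function names and the weighting factor lambda_cls are my assumptions, not taken from the original code.

```python
import numpy as np

def softmax_cross_entropy(logits, label):
    """Cross-entropy of one sample's class logits against an integer label."""
    z = logits - logits.max()                    # stabilize the exponentials
    log_probs = z - np.log(np.exp(z).sum())
    return -log_probs[label]

def generator_loss(critic_score, class_logits, label, lambda_cls=1.0):
    """Generator objective for the conditional WGAN sketch:
    minimize the negated critic score (i.e. maximize the Wasserstein score
    on the fake sample) plus a weighted classification loss from the
    separate classifier. lambda_cls is an assumed knob, not from the repo."""
    return -critic_score + lambda_cls * softmax_cross_entropy(class_logits, label)

# Toy example: 3 classes, the sample is conditioned on class 1.
loss = generator_loss(critic_score=0.7,
                      class_logits=np.array([0.1, 2.0, -1.0]),
                      label=1)
print(round(loss, 4))
```

Keeping the two terms separate like this is what lets the classifier gradient reach the generator without the critic ever having to output class probabilities.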

Thoughts

  • The WassersteinGAN addresses several problems of the traditional GAN, including mode collapse and training instability, so I think it is worth investigating how to train a conditional WassersteinGAN.
  • Initial attempts at training a conditional WassersteinGAN all failed: the critic in a WassersteinGAN no longer performs classification, so traditional approaches like CGAN do not work.
  • The separate classifier is inspired by Conditional Image Synthesis With Auxiliary Classifier GANs. In their original structure, the discriminator also outputs a classification loss and optimizes it together with the real/fake loss, but that trick does not work here because the classification loss would be dominated by the Wasserstein distance.
  • One drawback of using a separate classifier is that the balance between the generator and the classifier needs attention. Training the classifier also affects the generator, so if the classifier is trained for too many iterations, the discriminator becomes hard to converge. In my implementation I use a simple scheduling trick so that the ratio of generator to discriminator iterations remains roughly 1:5, as in the original WassersteinGAN implementation.
  • Training is slower than for the original WassersteinGAN. This makes sense because the optimizer needs more effort to find the correct direction. I increased the learning rate so that training runs a little faster.
  • The change of losses during training: (loss plot omitted in this text version)
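The scheduling trick mentioned above is not spelled out in detail; one plausible interleaving (my illustration, not the author's code) runs n_critic critic updates, then one classifier update, then one generator update per cycle, which keeps the generator-to-critic ratio at roughly 1:5 as in WGAN:

```python
def make_schedule(total_steps, n_critic=5):
    """Illustrative training schedule (an assumption, not the original code):
    each cycle performs n_critic critic updates, one classifier update, and
    one generator update, so generator:critic stays roughly 1:n_critic."""
    steps = []
    done = 0
    while done < total_steps:
        steps.extend(["critic"] * n_critic)  # Wasserstein critic updates
        steps.append("classifier")           # auxiliary-classifier update
        steps.append("generator")            # generator update against both heads
        done += n_critic + 2
    return steps

sched = make_schedule(14)
print(sched.count("critic") / sched.count("generator"))  # → 5.0
```

The classifier update is folded into the cycle rather than given its own loop, so adding it does not disturb the 1:5 balance the critic needs to stay near optimality.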

Result

  • As the samples show, the results are not perfect: some categories still have other digits mixed in. This may be caused by the conflict between the Wasserstein distance and the classification loss (cross entropy), but I have not found a better way to resolve it. More iterations improve the results a little, but not much.

References

  • The original TensorFlow implementation of WassersteinGAN is adapted from Zardinality's work - link
  • AC-GAN implementation by buriburisuri - link
