adversarial_autoencoder's Introduction

Adversarial autoencoders

[Cover image]

This repository contains code to implement an adversarial autoencoder using TensorFlow.

Medium posts:

  1. A Wizard's guide to Adversarial Autoencoders: Part 1. Autoencoders?

  2. A Wizard's guide to Adversarial Autoencoders: Part 2. Exploring the latent space with Adversarial Autoencoders.

  3. A Wizard's guide to Adversarial Autoencoders: Part 3. Disentanglement of style and content.

  4. A Wizard's guide to Adversarial Autoencoders: Part 4. Classify MNIST using 1000 labels.

Installing the dependencies

Install virtualenv and create a new virtual environment:

pip install virtualenv
virtualenv -p /usr/bin/python3 aa
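
Activate the environment before installing the dependencies (aa is the environment created above):

source aa/bin/activate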

Install dependencies

pip3 install -r requirements.txt

Note:

  • I'd highly recommend using your GPU during training.
  • tf.nn.sigmoid_cross_entropy_with_logits has a targets parameter that was renamed to labels in TensorFlow versions after r0.12 (see the snippet below).
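
A minimal illustration of the newer call signature (the logits here are a hypothetical placeholder, not the repo's graph):

    import tensorflow as tf

    # Hypothetical discriminator logits; only the keyword argument matters here:
    # TensorFlow > r0.12 expects labels=..., older releases used targets=...
    d_real = tf.placeholder(tf.float32, shape=[None, 1])
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_real), logits=d_real))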

Dataset

The MNIST dataset will be downloaded automatically and made available in the ./Data directory.

Training!

Autoencoder:

Architecture:

To train a basic autoencoder run:

    python3 autoencoder.py --train True
  • This trains an autoencoder and saves the trained model once every epoch in the ./Results/Autoencoder directory.

To load the trained model and generate images by passing inputs to the decoder, run:

    python3 autoencoder.py --train False
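
For orientation, here is a minimal sketch of the kind of dense autoencoder autoencoder.py trains; the layer sizes and reconstruction loss below are illustrative assumptions, not the repo's exact values:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[None, 784], name='input')   # flattened MNIST image

    def encoder(x):
        h = tf.layers.dense(x, 1000, activation=tf.nn.relu, name='e_dense_1')
        h = tf.layers.dense(h, 1000, activation=tf.nn.relu, name='e_dense_2')
        return tf.layers.dense(h, 2, name='e_latent_variable')        # 2-D latent code

    def decoder(z):
        h = tf.layers.dense(z, 1000, activation=tf.nn.relu, name='d_dense_1')
        h = tf.layers.dense(h, 1000, activation=tf.nn.relu, name='d_dense_2')
        return tf.layers.dense(h, 784, activation=tf.nn.sigmoid, name='d_output')

    z = encoder(x)
    x_hat = decoder(z)
    reconstruction_loss = tf.reduce_mean(tf.square(x - x_hat))        # pixel-wise MSE
    train_op = tf.train.AdamOptimizer(1e-3).minimize(reconstruction_loss)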

Adversarial Autoencoder:

Architecture:

[Architecture diagram]

Training:

    python3 adversarial_autoencoder.py --train True

Load model and explore the latent space:

    python3 adversarial_autoencoder.py --train False
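
A minimal sketch of the adversarial part that adversarial_autoencoder.py adds on top of the autoencoder: a discriminator sees samples from the prior ("real") and encoder outputs ("fake"), and the encoder doubles as the generator. Names and sizes below are illustrative, not the repo's exact code:

    import tensorflow as tf

    z_dim = 2
    real_z = tf.placeholder(tf.float32, shape=[None, z_dim])   # samples from the prior p(z)
    fake_z = tf.placeholder(tf.float32, shape=[None, z_dim])   # encoder output q(z|x)

    def discriminator(z, reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            h = tf.layers.dense(z, 1000, activation=tf.nn.relu)
            h = tf.layers.dense(h, 1000, activation=tf.nn.relu)
            return tf.layers.dense(h, 1)                        # logit: real vs. fake

    d_real = discriminator(real_z)
    d_fake = discriminator(fake_z, reuse=True)

    bce = tf.nn.sigmoid_cross_entropy_with_logits
    dc_loss = tf.reduce_mean(bce(labels=tf.ones_like(d_real), logits=d_real)) + \
              tf.reduce_mean(bce(labels=tf.zeros_like(d_fake), logits=d_fake))
    # The generator (encoder) is trained to make the discriminator label its output as "real".
    generator_loss = tf.reduce_mean(bce(labels=tf.ones_like(d_fake), logits=d_fake))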

Example of adversarial autoencoder output when the encoder is constrained to have a stddev of 5.

Matching prior and posterior distributions.

Distribution of digits in the latent space.
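
The prior being matched is a 2-D Gaussian; a minimal sketch of sampling it with the standard deviation of 5 mentioned above (variable names are illustrative):

    import numpy as np

    batch_size, z_dim = 100, 2
    # Target prior p(z): zero-mean Gaussian with standard deviation 5, the
    # distribution the discriminator pushes the encoder output to match.
    real_z = 5.0 * np.random.randn(batch_size, z_dim)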

Supervised Adversarial Autoencoder:

Architecture:

[Architecture diagram]

Training:

    python3 supervised_adversarial_autoencoder.py --train True

Load model and explore the latent space:

    python3 supervised_adversarial_autoencoder.py --train False

Example of disentanglement of style and content.
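
Disentanglement comes from giving the decoder the class label (content) as a one-hot vector alongside the latent code (style), so the code is free to capture style only. A minimal sketch of that decoder input; names and sizes are illustrative assumptions:

    import tensorflow as tf

    z_dim, n_classes = 2, 10
    z = tf.placeholder(tf.float32, shape=[None, z_dim], name='style')               # latent code
    y = tf.placeholder(tf.float32, shape=[None, n_classes], name='label_one_hot')   # digit class

    decoder_input = tf.concat([y, z], axis=1)    # content + style
    h = tf.layers.dense(decoder_input, 1000, activation=tf.nn.relu)
    h = tf.layers.dense(h, 1000, activation=tf.nn.relu)
    x_hat = tf.layers.dense(h, 784, activation=tf.nn.sigmoid)                       # reconstructed image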

Semi-Supervised Adversarial Autoencoder:

Architecture:

[Architecture diagram]

Training:

    python3 semi_supervised_adversarial_autoencoder.py --train True

Load model and explore the latent space:

    python3 semi_supervised_adversarial_autoencoder.py --train False
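
In the semi-supervised variant, the encoder outputs both a categorical code y (the digit class) and a continuous code z (style); each code gets its own discriminator, and a softmax cross-entropy term is added for the small labeled subset. A minimal sketch of the encoder head and the supervised term; names and sizes are illustrative assumptions:

    import tensorflow as tf

    n_classes, z_dim = 10, 2
    x = tf.placeholder(tf.float32, shape=[None, 784])
    labels = tf.placeholder(tf.float32, shape=[None, n_classes])   # only for the labeled images

    h = tf.layers.dense(x, 1000, activation=tf.nn.relu)
    h = tf.layers.dense(h, 1000, activation=tf.nn.relu)
    y_logits = tf.layers.dense(h, n_classes, name='categorical_code')   # to the categorical discriminator
    z = tf.layers.dense(h, z_dim, name='gaussian_code')                 # to the Gaussian discriminator

    supervised_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=y_logits))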

Classification accuracy for 1000 labeled images:

[Accuracy plots]

Note:

  • Each run generates the required TensorBoard files under the ./Results/<model>/<time_stamp_and_parameters>/Tensorboard directory.
  • Use tensorboard --logdir <tensorboard_dir> to look at loss variations and distributions of the latent code.
  • Windows gives an error when : is used in folder names (this happens during folder creation for each run). I suggest removing the time stamp from the folder_name variable in the form_results() function, or using a colon-free timestamp as sketched below. Or just dual-boot Linux!
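
A colon-free timestamp is enough to keep Windows happy; a minimal sketch (the format string is just one possibility) that could be substituted into the folder name built in form_results():

    from datetime import datetime

    # e.g. '2019-02-24_17-00-07' - no ':' characters, so valid as a Windows folder name
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')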

Thank You

Please share this repo if you find it helpful.

adversarial_autoencoder's People

Contributors

naresh1318, warvito, wkelongws

adversarial_autoencoder's Issues

Reproducing quantitative results

Hi Naresh1318,

Thank you very much for the code. I ran your implementation and was able to get the results reported in README.md.
However, my question is: have you been able to reach an error rate of 1.90 (±0.10)% for MNIST with 100 labels (as reported in the original paper)? Even after 5000 epochs, the lowest I get is around 5%.

Bad looking losses

Good Morning!

I am trying to reproduce the very nice results you have in your tutorial, but my training does not go as well when I execute python adversarial_autoencoder.py.

Here are the losses I obtain:
[loss curves screenshot]
I basically tried decreasing the learning rate, without effect.
Can you help me?

Thanks in advance,

Pierre

Encoder latent representation

Thanks for the great tutorial :-)

How can I save the encoder model? I know how that can be done in Keras, but not in TF. Basically, I would like to save the latent representation of the input, e.g. the (n, 2)-dimensional latent representation of the data.
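
For reference, a minimal plain-TF1 sketch (not the repo's code) of fetching latent codes from an encoder tensor and saving them with NumPy; the encoder and batch here are stand-ins:

    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[None, 784])
    z = tf.layers.dense(x, 2, name='latent')          # stand-in for the trained encoder

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())   # or restore weights with tf.train.Saver()
        batch = np.random.rand(16, 784).astype(np.float32)   # stand-in for a batch of MNIST images
        codes = sess.run(z, feed_dict={x: batch})             # shape (16, 2)
        np.save('latent_codes.npy', codes)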

categorical discriminator loss compared with 1

    dc_c_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_c_fake), logits=d_c_fake))

Hi, I'm trying to implement the adversarial autoencoder myself. I found this repo and am using it as a guide. It trains the autoencoder part, but I think this line is a bug: shouldn't the categorical discriminator be trained like the Gaussian one? If it's not a bug, could you comment on the training objective? Thanks!

Terminating with uncaught exception of type NSException:

When I try to run supervised_adversarial_autoencoder.py, I get the following uncaught exception:

2019-02-24 17:00:07.426921: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-02-24 17:00:07.586 python[6976:8326575] -[NSApplication _setup:]: unrecognized selector sent to instance 0x7fd270bef700
2019-02-24 17:00:07.587 python[6976:8326575] *** Terminating app due to uncaught exception 'NSInvalidArgumentException', reason: '-[NSApplication _setup:]: unrecognized selector sent to instance 0x7fd270bef700'
*** First throw call stack:
(
0 CoreFoundation 0x00007fff4f351ecd __exceptionPreprocess + 256
1 libobjc.A.dylib 0x00007fff7b409720 objc_exception_throw + 48
2 CoreFoundation 0x00007fff4f3cf275 -[NSObject(NSObject) __retain_OA] + 0
3 CoreFoundation 0x00007fff4f2f3b40 forwarding + 1486
4 CoreFoundation 0x00007fff4f2f34e8 _CF_forwarding_prep_0 + 120
5 libtk8.6.dylib 0x000000012e3f131d TkpInit + 413
6 libtk8.6.dylib 0x000000012e34917e Initialize + 2622
7 _tkinter.cpython-36m-darwin.so 0x000000012e171a16 _tkinter_create + 1174
8 python 0x000000010e401068 _PyCFunction_FastCallDict + 200
9 python 0x000000010e4d661f call_function + 143
10 python 0x000000010e4d4175 _PyEval_EvalFrameDefault + 46837
11 python 0x000000010e4c78c9 _PyEval_EvalCodeWithName + 425
12 python 0x000000010e4d72cc _PyFunction_FastCallDict + 364
13 python 0x000000010e37ff80 _PyObject_FastCallDict + 320
14 python 0x000000010e3a75f8 method_call + 136
15 python 0x000000010e3875ce PyObject_Call + 62
16 python 0x000000010e4285b5 slot_tp_init + 117
17 python 0x000000010e42caf1 type_call + 241
18 python 0x000000010e37fef1 _PyObject_FastCallDict + 177
19 python 0x000000010e388137 _PyObject_FastCallKeywords + 327
20 python 0x000000010e4d6718 call_function + 392
21 python 0x000000010e4d4225 _PyEval_EvalFrameDefault + 47013
22 python 0x000000010e4d69dc fast_function + 188
23 python 0x000000010e4d667c call_function + 236
24 python 0x000000010e4d4175 _PyEval_EvalFrameDefault + 46837
25 python 0x000000010e4c78c9 _PyEval_EvalCodeWithName + 425
26 python 0x000000010e4d72cc _PyFunction_FastCallDict + 364
27 python 0x000000010e37ff80 _PyObject_FastCallDict + 320
28 python 0x000000010e3a75f8 method_call + 136
29 python 0x000000010e3875ce PyObject_Call + 62
30 python 0x000000010e4d4376 _PyEval_EvalFrameDefault + 47350
31 python 0x000000010e4c78c9 _PyEval_EvalCodeWithName + 425
32 python 0x000000010e4d6a8a fast_function + 362
33 python 0x000000010e4d667c call_function + 236
34 python 0x000000010e4d4175 _PyEval_EvalFrameDefault + 46837
35 python 0x000000010e4d69dc fast_function + 188
36 python 0x000000010e4d667c call_function + 236
37 python 0x000000010e4d4175 _PyEval_EvalFrameDefault + 46837
38 python 0x000000010e4c78c9 _PyEval_EvalCodeWithName + 425
39 python 0x000000010e4d6a8a fast_function + 362
40 python 0x000000010e4d667c call_function + 236
41 python 0x000000010e4d4175 _PyEval_EvalFrameDefault + 46837
42 python 0x000000010e4c78c9 _PyEval_EvalCodeWithName + 425
43 python 0x000000010e4d6a8a fast_function + 362
44 python 0x000000010e4d667c call_function + 236
45 python 0x000000010e4d4225 _PyEval_EvalFrameDefault + 47013
46 python 0x000000010e4c78c9 _PyEval_EvalCodeWithName + 425
47 python 0x000000010e4d6a8a fast_function + 362
48 python 0x000000010e4d667c call_function + 236
49 python 0x000000010e4d4225 _PyEval_EvalFrameDefault + 47013
50 python 0x000000010e4c78c9 _PyEval_EvalCodeWithName + 425
51 python 0x000000010e52055c PyRun_FileExFlags + 252
52 python 0x000000010e51fa34 PyRun_SimpleFileExFlags + 372
53 python 0x000000010e5467c6 Py_Main + 3734
54 python 0x000000010e377f59 main + 313
55 libdyld.dylib 0x00007fff7c4d7ed9 start + 1
56 ??? 0x0000000000000002 0x0 + 2
)
libc++abi.dylib: terminating with uncaught exception of type NSException

Process finished with exit code 134 (interrupted by signal 6: SIGABRT)

Can anyone please help me with this?

Trainable variables for the generator optimizer

Hi Naresh,

Really appreciate you taking the time to make the AAE tutorial. It is a great read!

I have a question regarding the implementation of generator_optimizer in the code. When I print en_var, I get the following list of variables:

    e_dense_1/weights:0
    e_dense_1/bias:0
    e_dense_2/weights:0
    e_dense_2/bias:0
    e_latent_variable/weights:0
    e_latent_variable/bias:0
    d_dense_1/weights:0
    d_dense_1/bias:0
    d_dense_2/weights:0
    d_dense_2/bias:0

In your post, you mention that:

We’ll backprop only through the encoder weights, which causes the encoder to learn the required distribution and produce output which’ll have that distribution.

Do the decoder weights get updated as well?

Can the dimension of the latent representation be larger, for example 10 or 128?

Thank you for your code!
Sorry, I don't know much about TF.
When I tried the unsupervised experiment, I tried to increase the dimension of the latent representation (e.g. to 10 or 128), but the generator was completely broken and my attempts to fix it failed. I wonder if you have tried this; could you please provide some information? Thank you very much!
