
GANs-Implementations

Implement GANs with PyTorch.

Progress

Unconditional image generation (CIFAR-10):

  • DCGAN (vanilla GAN)
  • DCGAN + R1 regularization
  • WGAN
  • WGAN-GP
  • SNGAN
  • LSGAN

Conditional image generation (CIFAR-10):

  • CGAN
  • ACGAN

Unsupervised decomposition (MNIST, FFHQ):

  • InfoGAN
  • EigenGAN

Mode collapse study (Ring8, MNIST):

  • GAN (vanilla GAN)
  • GAN + R1 regularization
  • WGAN
  • WGAN-GP
  • SNGAN
  • LSGAN
  • VEEGAN

Unconditional Image Generation

Notes:

| Model | G. Arch. | D. Arch. | Loss | Configs & Args |
|-------|----------|----------|------|----------------|
| DCGAN | SimpleCNN | SimpleCNN | Vanilla | config file |
| DCGAN + R1 reg | SimpleCNN | SimpleCNN | Vanilla + R1 regularization | config file; additional args: --train.loss_fn.params.lambda_r1_reg 10.0 |
| WGAN | SimpleCNN | SimpleCNN | Wasserstein (weight clipping) | config file |
| WGAN-GP | SimpleCNN | SimpleCNN | Wasserstein (gradient penalty) | config file |
| SNGAN | SimpleCNN | SimpleCNN (SN) | Vanilla | config file |
| SNGAN | SimpleCNN | SimpleCNN (SN) | Hinge | config file |
| LSGAN | SimpleCNN | SimpleCNN | Least Square | config file |

  • SN stands for "Spectral Normalization".

  • For simplicity, the network architecture in all experiments is SimpleCNN, namely a stack of nn.Conv2d or nn.ConvTranspose2d layers. The results could be improved by adding more parameters and using more advanced architectures (e.g., residual connections), but I decided to use the simplest setup here; a rough sketch of such a generator is given after these notes.

  • All models except LSGAN are trained for 40k generator update steps. However, the optimizers and learning rates are not tuned for each model, so some models may not reach their best performance.
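
For reference, the snippet below sketches what a SimpleCNN-style generator for 32×32 images might look like: just a stack of nn.ConvTranspose2d blocks as described above. The latent dimension, channel widths, and normalization layers are illustrative assumptions, not the exact configuration used in this repo.

```python
import torch.nn as nn

# Hypothetical SimpleCNN-style generator: a plain stack of ConvTranspose2d blocks
# mapping a (B, z_dim, 1, 1) latent tensor to a (B, 3, 32, 32) image.
z_dim = 128
G = nn.Sequential(
    nn.ConvTranspose2d(z_dim, 256, kernel_size=4, stride=1, padding=0),  # 1x1 -> 4x4
    nn.BatchNorm2d(256), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),    # 4x4 -> 8x8
    nn.BatchNorm2d(128), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),     # 8x8 -> 16x16
    nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),       # 16x16 -> 32x32
    nn.Tanh(),
)
```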

Quantitative results:

| Model | FID ↓ | Inception Score ↑ |
|-------|-------|--------------------|
| DCGAN | 24.7311 | 7.0339 ± 0.0861 |
| DCGAN + R1 reg | 24.1535 | 7.0188 ± 0.1089 |
| WGAN | 49.9169 | 5.6852 ± 0.0649 |
| WGAN-GP | 28.7963 | 6.7241 ± 0.0784 |
| SNGAN (vanilla loss) | 24.9151 | 6.8838 ± 0.0667 |
| SNGAN (hinge loss) | 28.5197 | 6.7429 ± 0.0818 |
| LSGAN | 28.4850 | 6.7465 ± 0.0911 |

  • The FID is calculated between 50k generated samples and the CIFAR-10 training split (50k images).
  • The Inception Score is calculated on 50k generated samples.

Visualization:

(Sample grids: DCGAN, DCGAN + R1 reg, WGAN, WGAN-GP, SNGAN (vanilla loss), SNGAN (hinge loss), LSGAN)

Conditional Image Generation

Notes:

| Model | G. Arch. | D. Arch. | G. cond. | D. cond. | Loss | Configs & Args |
|-------|----------|----------|----------|----------|------|----------------|
| CGAN | SimpleCNN | SimpleCNN | concat | concat | Vanilla | config file |
| CGAN (cBN) | SimpleCNN | SimpleCNN | cBN | concat | Vanilla | config file |
| ACGAN | SimpleCNN | SimpleCNN | cBN | AC | Vanilla | config file |

  • cBN stands for "conditional Batch Normalization"; SN stands for "Spectral Normalization"; AC stands for "Auxiliary Classifier"; PD stands for "Projection Discriminator".
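
As a reference for the "cBN" conditioning above, here is a minimal sketch of conditional Batch Normalization: a plain BN layer whose scale and shift are looked up per class. The layer layout is an assumption for illustration and may differ from the implementation in this repo.

```python
import torch
import torch.nn as nn

class ConditionalBatchNorm2d(nn.Module):
    """Minimal conditional BN sketch: shared BN statistics, per-class scale and shift."""

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)   # normalization without affine params
        self.gamma = nn.Embedding(num_classes, num_features)   # per-class scale
        self.beta = nn.Embedding(num_classes, num_features)    # per-class shift
        nn.init.ones_(self.gamma.weight)
        nn.init.zeros_(self.beta.weight)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) feature map, y: (B,) integer class labels
        out = self.bn(x)
        gamma = self.gamma(y).unsqueeze(-1).unsqueeze(-1)
        beta = self.beta(y).unsqueeze(-1).unsqueeze(-1)
        return gamma * out + beta
```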

Quantitative results:

| Model | FID ↓ | Intra FID ↓ | Inception Score ↑ |
|-------|-------|-------------|--------------------|
| CGAN | 25.4999 | 47.7334 | 7.5597 ± 0.0909 |
| CGAN (cBN) | 25.3466 | 47.4136 | 7.7541 ± 0.0944 |
| ACGAN | 19.9154 | 49.9892 | 7.9903 ± 0.1038 |

Per-class intra FID details:

| Class | CGAN | CGAN (cBN) | ACGAN |
|-------|------|------------|-------|
| 0 | 53.4163 | 51.5959 | 47.3203 |
| 1 | 44.3311 | 46.6855 | 38.6481 |
| 2 | 53.1971 | 49.9857 | 62.5885 |
| 3 | 52.2223 | 53.6737 | 66.2386 |
| 4 | 36.9577 | 35.1658 | 64.5535 |
| 5 | 65.0020 | 65.7719 | 60.7876 |
| 6 | 37.9598 | 38.0958 | 58.9524 |
| 7 | 48.3610 | 44.7279 | 36.8940 |
| 8 | 41.8075 | 43.3078 | 28.5964 |
| 9 | 44.0796 | 45.1265 | 35.3120 |

  • The FID is calculated between 50k generated samples (5k for each class) and the CIFAR-10 training split (50k images).
  • The intra FID is calculated between 5k generated samples and the CIFAR-10 training split within each class.
  • The Inception Score is calculated on 50k generated samples.

Visualizations:

(Sample grids: CGAN, CGAN (cBN), ACGAN)

Unsupervised Decomposition

InfoGAN

  • Left: change the discrete latent variable, which corresponds to the digit type.
  • Right: change one of the continuous latent variables from -1 to 1. However, the decomposition is not clear.
  • Note: I found that batch normalization layers play an important role in InfoGAN. Without BN layers, the discrete latent variable tends to have a clear meaning as shown above, while the continuous variables have little effect. On the contrary, with BN layers, it is harder for the discrete variable to capture the digit type and easier for the continuous ones to discover rotation of the digits. (A sketch of the latent-code split is given below.)
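
For reference, here is a minimal sketch of how InfoGAN's latent input could be assembled on MNIST: random noise plus a one-hot discrete code (intended to capture the digit type) plus a few continuous codes. The sizes follow the common MNIST setup from the InfoGAN paper and are an assumption about this repo's configuration.

```python
import torch
import torch.nn.functional as F

# Latent input = noise + one-hot discrete code + continuous codes.
batch_size, noise_dim, n_discrete, n_continuous = 32, 62, 10, 2

z_noise = torch.randn(batch_size, noise_dim)
c_disc = F.one_hot(torch.randint(0, n_discrete, (batch_size,)), n_discrete).float()
c_cont = torch.rand(batch_size, n_continuous) * 2 - 1    # uniform in [-1, 1]

z = torch.cat([z_noise, c_disc, c_cont], dim=1)           # generator input
# A Q-head on the discriminator predicts (c_disc, c_cont) back from G(z); its
# classification / regression loss approximates the mutual-information term.
```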

EigenGAN

Random samples (no truncation):

Traverse:


Mode Collapse Study

Mode collapse is a notorious problem in GANs, where the model can only generate a few modes of the real data. Various methods have been proposed to solve it. To study this problem, I experimented with different methods on the following two datasets:

  • Ring8: eight Gaussian distributions lying on a ring.
  • MNIST: handwritten digit dataset.

For simplicity, the model architecture in all experiments is SimpleMLP, namely a stack of nn.Linear layers, so the quality of the generated MNIST images may not be very good. However, this section aims to demonstrate the mode collapse problem rather than to achieve the best image quality.
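
For reference, the Ring8 data described above can be sampled roughly as follows; the ring radius and the standard deviation of each Gaussian are assumptions, and the repo's dataset code may use different values.

```python
import math
import torch

def sample_ring8(n: int, radius: float = 1.0, std: float = 0.05) -> torch.Tensor:
    """Sample n points from 8 Gaussians whose means are evenly spaced on a circle."""
    angles = torch.randint(0, 8, (n,)).float() * (2 * math.pi / 8)
    centers = torch.stack([radius * torch.cos(angles), radius * torch.sin(angles)], dim=1)
    return centers + std * torch.randn(n, 2)

x_real = sample_ring8(512)   # (512, 2) points used as "real" data for the toy GANs
```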


GAN

(Ring8 snapshots at 200 / 400 / 600 / 800 / 1000 steps; MNIST snapshots at 1000 / 2000 / 3000 / 4000 / 5000 steps)

On the Ring8 dataset, it can be clearly seen that all the generated samples collapse onto only one of the 8 modes.

In the MNIST case, the generated images eventually collapse to the digit 1.


GAN + R1 regularization

(Ring8 snapshots at 200 / 400 / 600 / 800 / 5000 steps; MNIST snapshots at 1000 / 3000 / 5000 / 7000 / 9000 steps)

R1 regularization, a technique to stabilize the training process of GANs, can prevent mode collapse in vanilla GAN as well.
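
For reference, a minimal sketch of the R1 penalty (the same regularizer weighted by --train.loss_fn.params.lambda_r1_reg in the unconditional CIFAR-10 experiments) is given below; whether the λ/2 factor is folded into the weight is an implementation detail and an assumption here.

```python
import torch

def r1_penalty(d_real: torch.Tensor, x_real: torch.Tensor) -> torch.Tensor:
    """R1 penalty: squared norm of grad(D(x)) w.r.t. real inputs (x_real must require grad)."""
    grad, = torch.autograd.grad(outputs=d_real.sum(), inputs=x_real, create_graph=True)
    return grad.pow(2).flatten(start_dim=1).sum(dim=1).mean()

# Hypothetical usage in the discriminator step (D, x_real, x_fake, lambda_r1 assumed;
# bce = nn.BCEWithLogitsLoss()):
#   x_real.requires_grad_(True)
#   d_real = D(x_real)
#   loss_d = bce(d_real, torch.ones_like(d_real)) \
#          + bce(D(x_fake.detach()), torch.zeros_like(d_real)) \
#          + 0.5 * lambda_r1 * r1_penalty(d_real, x_real)
```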


WGAN

(Ring8 snapshots at 200 / 400 / 600 / 800 / 5000 steps; MNIST snapshots at 1000 / 3000 / 5000 / 7000 / 9000 steps)

WGAN indeed resolves the mode collapse problem, but it converges much more slowly due to weight clipping.
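
The weight clipping itself is just a clamp over the critic's parameters after each update, roughly as sketched below; the clip value 0.01 is the one from the WGAN paper and the toy critic is hypothetical, so both are assumptions rather than this repo's exact settings.

```python
import torch
import torch.nn as nn

# Toy critic just for illustration; the actual experiments use SimpleMLP / SimpleCNN.
D = nn.Sequential(nn.Linear(2, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1))

clip_value = 0.01  # value from the WGAN paper; assumed here
with torch.no_grad():
    for p in D.parameters():                 # applied after every critic update step
        p.clamp_(-clip_value, clip_value)
```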


WGAN-GP

(Ring8 snapshots at 200 / 400 / 600 / 800 / 5000 steps; MNIST snapshots at 1000 / 3000 / 5000 / 7000 / 9000 steps)

WGAN-GP improves WGAN by replacing the hard weight clipping with the soft gradient penalty.

The pathological weight distribution seen in WGAN's discriminator does not appear in WGAN-GP.
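
A minimal sketch of the gradient penalty, computed at random interpolations between real and fake samples, is given below; the penalty weight (commonly 10) and the other details are assumptions rather than this repo's exact configuration.

```python
import torch
import torch.nn as nn

def gradient_penalty(D: nn.Module, x_real: torch.Tensor, x_fake: torch.Tensor) -> torch.Tensor:
    """WGAN-GP: push the critic's gradient norm towards 1 at points between real and fake data."""
    eps_shape = (x_real.size(0),) + (1,) * (x_real.dim() - 1)
    eps = torch.rand(eps_shape, device=x_real.device)              # per-sample mixing weight
    x_hat = (eps * x_real + (1 - eps) * x_fake.detach()).requires_grad_(True)
    d_hat = D(x_hat)
    grad, = torch.autograd.grad(outputs=d_hat.sum(), inputs=x_hat, create_graph=True)
    grad_norm = grad.flatten(start_dim=1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()

# Hypothetical usage for the critic loss:
#   loss_d = D(x_fake.detach()).mean() - D(x_real).mean() + 10.0 * gradient_penalty(D, x_real, x_fake)
```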


SNGAN

(Ring8 snapshots at 200 / 400 / 600 / 800 / 5000 steps; MNIST snapshots at 1000 / 3000 / 5000 / 7000 / 9000 steps)

Note: The above SNGAN is trained with the vanilla GAN loss instead of the hinge loss.

SNGAN uses spectral normalization to control the Lipschitz constant of the discriminator. Even with the vanilla GAN loss, SNGAN avoids the mode collapse problem.
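
In PyTorch, spectral normalization can be applied by wrapping each weight layer with torch.nn.utils.spectral_norm, roughly as below; the MLP layout is hypothetical and only meant to illustrate how the "SN" discriminators differ from the plain ones.

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Wrapping each weight layer constrains its spectral norm to ~1, which bounds the
# Lipschitz constant of the whole critic.
D = nn.Sequential(
    spectral_norm(nn.Linear(2, 128)), nn.LeakyReLU(0.2),
    spectral_norm(nn.Linear(128, 128)), nn.LeakyReLU(0.2),
    spectral_norm(nn.Linear(128, 1)),
)
```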


LSGAN

(Ring8 snapshots at 200 / 400 / 600 / 800 / 5000 steps; MNIST snapshots at 1000 / 3000 / 5000 / 7000 / 9000 steps)

LSGAN uses MSE instead of cross-entropy as the loss function to overcome the vanishing gradients in the vanilla GAN. However, it still suffers from the mode collapse problem. For example, as shown above, LSGAN fails to cover all 8 modes on the Ring8 dataset.

Note: Contrary to the claim in the paper, I found that LSGAN w/o batch normalization does not converge on MNIST.
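
For reference, the least-squares objective amounts to MSE losses against fixed real/fake targets, roughly as sketched below; the 0/1 target values and the 0.5 scaling follow common practice and are assumptions about this repo's exact configuration.

```python
import torch
import torch.nn.functional as F

def d_loss_lsgan(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    """Discriminator: push real outputs towards 1 and fake outputs towards 0."""
    return 0.5 * (F.mse_loss(d_real, torch.ones_like(d_real)) +
                  F.mse_loss(d_fake, torch.zeros_like(d_fake)))

def g_loss_lsgan(d_fake: torch.Tensor) -> torch.Tensor:
    """Generator: push fake outputs towards the real target."""
    return 0.5 * F.mse_loss(d_fake, torch.ones_like(d_fake))
```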


VEEGAN

(Ring8 snapshots at 200 / 400 / 600 / 800 / 5000 steps; MNIST snapshots at 1000 / 3000 / 5000 / 7000 / 10000 steps)

VEEGAN uses an extra network to reconstruct the latent codes from the generated data.
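
A rough sketch of that idea is given below: a reconstructor network maps generated samples back to latent space, and a reconstruction loss discourages many different codes from collapsing onto the same output. The architectures and dimensions are hypothetical, and the full VEEGAN objective additionally involves a discriminator on (x, z) pairs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

z_dim, x_dim = 64, 2
G = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(), nn.Linear(128, x_dim))   # generator
R = nn.Sequential(nn.Linear(x_dim, 128), nn.ReLU(), nn.Linear(128, z_dim))   # reconstructor

z = torch.randn(32, z_dim)
x_fake = G(z)
recon_loss = F.mse_loss(R(x_fake), z)   # added to the generator/reconstructor objective
```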


Run the code

Pretrained weights

The checkpoints and training logs are available at xyfJASON/GANs-Implementations on Hugging Face.

Train

For GAN, WGAN-GP, SNGAN, LSGAN:

accelerate-launch scripts/train.py -c ./configs/xxx.yaml

For WGAN (weight clipping), InfoGAN, VEEGAN, CGAN, ACGAN and EigenGAN, use the script with the corresponding name instead:

accelerate-launch scripts/train_xxxgan.py -c ./configs/xxx.yaml

Sample

Unconditional GANs:

accelerate-launch scripts/sample.py \
    -c ./configs/xxx.yaml \
    --weights /path/to/saved/ckpt/model.pt \
    --n_samples N_SAMPLES \
    --save_dir SAVE_DIR

Conditional GANs:

accelerate-launch scripts/sample_cond.py \
    -c ./configs/xxx.yaml \
    --weights /path/to/saved/ckpt/model.pt \
    --n_classes N_CLASSES \
    --n_samples_per_class N_SAMPLES_PER_CLASS \
    --save_dir SAVE_DIR

EigenGAN:

accelerate-launch scripts/sample_eigengan.py \
    -c ./configs/xxx.yaml \
    --weights /path/to/saved/ckpt/model.pt \
    --n_samples N_SAMPLES \
    --save_dir SAVE_DIR \
    --mode MODE

Evaluate

Sample images following the instructions above and use tools like torch-fidelity to calculate FID / IS.
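
For example, assuming the generated samples were saved as individual image files under ./samples, torch-fidelity's Python API can be used roughly as follows; the sample directory and the choice of the registered cifar10-train input are assumptions for this example.

```python
import torch_fidelity

metrics = torch_fidelity.calculate_metrics(
    input1='./samples',       # folder of generated images (assumed location)
    input2='cifar10-train',   # CIFAR-10 training split, matching the tables above
    cuda=True,
    isc=True,                 # Inception Score
    fid=True,                 # Fréchet Inception Distance
)
print(metrics)
```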
