rahulbhalley / progressive-growing-of-gans.pytorch

Unofficial PyTorch implementation of "Progressive Growing of GANs for Improved Quality, Stability, and Variation".

Home Page: https://arxiv.org/abs/1710.10196

License: MIT License

Python 99.95% Shell 0.05%
progressive-gan generative-adversarial-network pytorch discriminator generator wasserstein-gan wgan tensorboard optimal-transport

progressive-growing-of-gans.pytorch's Introduction

Progressive Growing of Generative Adversarial Networks

This is a PyTorch implementation of Progressive Growing of GANs. The network can be trained on a custom image dataset.

Place your dataset folder inside the data folder. Training stats are written to the repo folder as training progresses.

Training Configuration

The network training parameters can be configured with the following flags; an example invocation follows the lists below.

General settings

  • --train_data_root - set your data directory
  • --random_seed - random seed to reproduce the experiments
  • --n_gpu - number of GPUs to use for training

Training parameters

  • --lr - learning rate
  • --lr_decay - learning rate decay at every resolution transition
  • --eps_drift - coefficient for the drift loss
  • --smoothing - smoothing factor for smoothed generator
  • --nc - number of input channels
  • --nz - input dimension of noise
  • --ngf - feature dimension of final layer of generator
  • --ndf - feature dimension of first layer of discriminator
  • --TICK - 1 tick = 1000 images = (1000/batch_size) iterations
  • --max_resl - maximum resolution as a power of two (10 → 1024, 9 → 512, 8 → 256)
  • --trns_tick - transition tick
  • --stab_tick - stabilization tick

Network structure

  • --flag_wn - use of equalized learning rate
  • --flag_bn - use of batch-normalization (not recommended)
  • --flag_pixelwise - use of pixelwise normalization for generator
  • --flag_gdrop - use of generalized dropout layer for discriminator
  • --flag_leaky - use of leaky ReLU instead of ReLU
  • --flag_tanh - use of tanh at the end of the generator
  • --flag_sigmoid - use of sigmoid at the end of the discriminator
  • --flag_add_noise - add noise to the real image (x)
  • --flag_norm_latent - pixelwise normalization of latent vector (z)
  • --flag_add_drift - add drift loss

Optimizer setting

  • --optimizer - optimizer type
  • --beta1 - beta1 for Adam
  • --beta2 - beta2 for Adam

Display and save setting

  • --use_tb - enable TensorBoard visualization
  • --save_img_every - save images every specified iteration
  • --display_tb_every - display progress every specified iteration
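
For example, a single-GPU training run might be launched like this (a hypothetical invocation; my_dataset is a placeholder, and the exact flag parsing is defined in pggan.py):

python pggan.py --train_data_root ./data/my_dataset --n_gpu 1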

GPU Note

Make sure your machine has CUDA-enabled GPU(s) if you want to train on GPUs. Set the --n_gpu flag to a positive integer less than or equal to the number of available GPUs.

TODO

  • WGAN training methodology

Related Links

progressive-growing-of-gans.pytorch's People

Contributors

rahulbhalley

progressive-growing-of-gans.pytorch's Issues

It does not look like the models are initializing

I am having trouble running pggan on a fresh install of PyTorch on Windows 10. I get the following error:

python pggan.py --ngpu 1 --train_data_root \Users\Powerpop\Desktop\pggan\wikiart\popart
Configuration
train_data_root: \Users\Powerpop\Desktop\pggan\wikiart\popart
random_seed: 1558492887
n_gpu: 8
lr: 0.001
lr_decay: 0.87
eps_drift: 0.001
smoothing: 0.997
nc: 3
nz: 512
ngf: 512
ndf: 512
TICK: 1000
max_resl: 9
trns_tick: 200
stab_tick: 100
flag_wn: True
flag_bn: False
flag_pixelwise: True
flag_gdrop: True
flag_leaky: True
flag_tanh: False
flag_sigmoid: False
flag_add_noise: True
flag_norm_latent: False
flag_add_drift: True
optimizer: adam
beta1: 0.0
beta2: 0.99
use_tb: True
save_img_every: 20
display_tb_every: 5
C:\Users\Powerpop\Desktop\pggan\custom_layers.py:114: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
kaiming_normal(self.conv.weight, a=calculate_gain('conv2d'))
Traceback (most recent call last):
  File "pggan.py", line 364, in <module>
    pggan = PGGAN(config)
  File "pggan.py", line 61, in __init__
    self.G = Generator(config)
  File "C:\Users\Powerpop\Desktop\pggan\network.py", line 88, in __init__
    self.model = self.get_init_gen()
  File "C:\Users\Powerpop\Desktop\pggan\network.py", line 136, in get_init_gen
    self.module_names = get_module_names(model)
  File "C:\Users\Powerpop\Desktop\pggan\network.py", line 66, in get_module_names
    for key, val in model.state_dict().iteritems():
AttributeError: 'collections.OrderedDict' object has no attribute 'iteritems'
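
This is a Python 2 vs. Python 3 incompatibility: OrderedDict lost its iteritems() method in Python 3, so the loop in network.py's get_module_names fails. A minimal sketch of the one-line fix, assuming the rest of the function stays unchanged:

# network.py, get_module_names (sketch): .iteritems() exists only in
# Python 2; .items() works in both Python 2 and Python 3.
for key, val in model.state_dict().items():  # was: .iteritems()
    ...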

RuntimeError: randperm is only implemented for CPU

The code doesn't work in PyTorch v0.4.0 (the latest stable release):

Exception KeyError: KeyError(<weakref at 0x7f1fd0a76aa0; to 'tqdm' at 0x7f1fd0aad6d0>,) in <bound method tqdm.__del__ of 0%| | 0/18750 [00:00<?, ?it/s]> ignored
Traceback (most recent call last):
  File "pggan.py", line 365, in <module>
    pggan.train()
  File "pggan.py", line 269, in train
    self.x.data = self.feed_interpolated_input(self.loader.get_batch())
  File "Progressive-Growing-of-GANs/dataloader.py", line 57, in get_batch
    dataIter = iter(self.dataloader)
  File "Progressive-Growing-of-GANs-py2/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 451, in __iter__
    return _DataLoaderIter(self)
  File "Progressive-Growing-of-GANs-py2/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 247, in __init__
    self._put_indices()
  File "Progressive-Growing-of-GANs-py2/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 295, in _put_indices
    indices = next(self.sample_iter, None)
  File "Progressive-Growing-of-GANs-py2/lib/python2.7/site-packages/torch/utils/data/sampler.py", line 138, in __iter__
    for idx in self.sampler:
  File "Progressive-Growing-of-GANs-py2/lib/python2.7/site-packages/torch/utils/data/sampler.py", line 51, in __iter__
    return iter(torch.randperm(len(self.data_source)).tolist())
RuntimeError: randperm is only implemented for CPU

Here people recommend writing device-agnostic code to avoid such problems. What would you suggest as a quick fix?
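
A sketch of that device-agnostic pattern (names are illustrative): if the code sets the global default tensor type to a CUDA tensor, the DataLoader's sampler ends up calling torch.randperm on the GPU, which v0.4.0 did not implement; choosing a device explicitly avoids the global default entirely.

import torch
import torch.nn as nn

# Pick the device once instead of setting a global CUDA default
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = nn.Linear(512, 512).to(device)        # stand-in model; move the real generator the same way
z = torch.randn(16, 512, device=device)   # create tensors directly on the chosen device
out = G(z)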

GPUs have processes assigned but training takes as long as on CPU

Using PyTorch v1.0.1, I was initially getting this error:

RuntimeError: binary_op(): expected both inputs to be on same device, but input a is on cuda:1 and input b is on cuda:0

After using the register_buffer fix identified here (https://discuss.pytorch.org/t/tensors-are-on-different-gpus/1450/28) in the custom_layers.py file, I was able to get the program to run. GPU memory is being used, but the iterations are taking just as long as with CPU only.
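
For reference, a minimal sketch of the register_buffer idea (the layer name and noise behavior are illustrative, not the repo's actual custom_layers.py code): a plain tensor attribute stays on the device it was created on, while a registered buffer follows the module under .to()/.cuda() and is replicated to each GPU by DataParallel.

import torch
import torch.nn as nn

class NoiseLayer(nn.Module):  # hypothetical stand-in for a custom layer
    def __init__(self, strength=0.2):
        super().__init__()
        # was (roughly): self.strength = torch.tensor(strength),
        # which stays on one GPU; a buffer moves with the module
        self.register_buffer("strength", torch.tensor(strength))

    def forward(self, x):
        # randn_like creates the noise on x's device automatically
        return x + self.strength * torch.randn_like(x)

As for the speed, note that DataParallel adds per-iteration scatter/gather overhead, which can dominate at the low resolutions early in progressive training, so CPU-like iteration times at the start are not necessarily a bug.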

[Screenshot: GPU processes assigned during training]

Do you have any idea as to why this would be?

interpolation

Thanks a lot for sharing!
Didn't try it yet, but will soon :)

Note that what you refer to as "latent space interpolation" (and the animated GIF) are not really interpolations: you are generating random samples, but not interpolating between them. To get the "morphing" effect, as in the paper's video, you need to interpolate between the random samples in the latent z space and generate an image for each interpolated point.
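
A minimal sketch of such an interpolation (the placeholder G and nz = 512 are assumptions; substitute the trained generator and its expected input shape):

import torch
import torch.nn as nn

nz = 512
G = nn.Linear(nz, 3 * 64 * 64)   # placeholder; use the trained generator here

z0 = torch.randn(1, nz)          # two random endpoints in latent space
z1 = torch.randn(1, nz)

with torch.no_grad():
    frames = []
    for t in torch.linspace(0.0, 1.0, steps=10):
        z = (1 - t) * z0 + t * z1    # linear interpolation in z space
        frames.append(G(z))          # one generated frame per step -> morphing

With Gaussian latents, spherical interpolation (slerp) between the endpoints often looks smoother than the straight-line interpolation shown here.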

Resuming training

The code base does not support resuming training … and it doesn't save the model state in a way that would allow resuming. The code saves the state_dict data for the generator and discriminator as .tar files for some reason, even though torch.save just pickles the output (it is not a tar archive).

Second, for a checkpoint to be useful for resuming training, more data has to be stored, including the epoch, the model state_dict, and the optimizer state_dict:

# Save everything needed to resume
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)


# Resume: rebuild the model and optimizer, then restore their states
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()
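
Note that for progressive growing in particular, a resumable checkpoint would presumably also have to record the current resolution level and the fade-in/stabilization progress, since those determine the network structure at load time.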
