
progressive-growing-torch's Introduction

Progressive Growing of GANs for Improved Quality, Stability, and Variation


[NOTE] This project was not going well, so I made a PyTorch implementation here. 🔥 [pggan-pytorch]


Torch implementation of PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION
YOUR CONTRIBUTION IS INVALUABLE TO THIS PROJECT :)


NEED HELP

[ ] (1) Implementing a pixel-wise normalization layer
[ ] (2) Implementing pre-layer normalization (for equalized learning rate)
(I have tried both, but failed to converge. Can anyone help implement these two custom layers? See the sketch below.)
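
For reference, here is a minimal, untested sketch of how the pixel-wise normalization layer from the paper could look as a custom Torch nn module. This is my own draft, not code from this repo; the class name and epsilon value are arbitrary choices. It normalizes each pixel's feature vector: b = a / sqrt(mean_c(a^2) + eps).

require 'nn'

local PixelNorm, parent = torch.class('nn.PixelNorm', 'nn.Module')

function PixelNorm:__init(eps)
   parent.__init(self)
   self.eps = eps or 1e-8   -- small constant to avoid division by zero
end

function PixelNorm:updateOutput(input)
   -- input: B x C x H x W; norm is B x 1 x H x W
   self.norm = torch.pow(input, 2):mean(2):add(self.eps):sqrt()
   self.output = torch.cdiv(input, self.norm:expandAs(input))
   return self.output
end

function PixelNorm:updateGradInput(input, gradOutput)
   -- full gradient of y = x / n with n = sqrt(mean_c(x^2) + eps):
   -- dL/dx = g/n - x * sum_c(g .* x) / (C * n^3)
   local C = input:size(2)
   local n = self.norm:expandAs(input)
   local dot = torch.cmul(gradOutput, input):sum(2):expandAs(input)
   self.gradInput = torch.cdiv(gradOutput, n)
      :add(-1/C, torch.cmul(input, dot):cdiv(torch.pow(n, 3)))
   return self.gradInput
end

The equalized-learning-rate trick from the same paper is a separate mechanism: weights are stored at unit variance and scaled by the per-layer He constant sqrt(2 / fan_in) inside the forward pass, so all layers see a comparable effective learning rate.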

Prerequisites

How to use?

[step 1.] Prepare dataset
The CelebA-HQ dataset is not available yet, so I used the 100,000 generated PNGs of CelebA-HQ released by the author.
The quality of the generated images was good enough for training and verifying the performance of the code.
If the CelebA-HQ dataset is released in the near future, I will update the experimental results.
[download]

  • CAUTION: loading 1024 x 1024 images and resizing them on every forward pass makes training slow. I recommend using the normal CelebA dataset until the output resolution reaches 256x256. (A one-off preprocessing sketch follows the folder layout below.)
---------------------------------------------
The training data folder should look like : 
<train_data_root>
                |--classA
                        |--image1A
                        |--image2A ...
                |--classB
                        |--image1B
                        |--image2B ...
---------------------------------------------
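
Given the caution above, a cheap workaround is to resize the PNGs once offline rather than on every forward pass. A rough sketch using Torch's image and paths packages (the directory names are placeholders):

require 'image'
require 'paths'

-- one-off preprocessing: downscale each 1024x1024 PNG once, so the
-- dataloader never rescales inside the training loop
local src_dir, dst_dir = '<train_data_root>/classA', 'celeba-256/classA'  -- placeholders
paths.mkdir(dst_dir)
for file in paths.files(src_dir, '%.png$') do
   local img = image.load(paths.concat(src_dir, file), 3, 'float')
   image.save(paths.concat(dst_dir, file), image.scale(img, 256, 256))
end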

[step 2.] Run training

  • edit script/opts.lua to change the training parameters. (Don't forget to change the path to the training images.)
  • run and enjoy! (Multi-threaded dataloading is supported.)
     $ python run.py

[step 3.] Visualization

  • to start display server:
     $ th server.lua
  • to check images during the training procedure:
     open <server_ip>:<port> in your browser
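
If the display server is the common szym/display package (an assumption on my part, based on the th server.lua entry point), pushing an image to the browser from the training loop looks roughly like this:

require 'torch'
local display = require 'display'   -- assumes szym/display is installed

local fake = torch.rand(3, 64, 64)  -- stand-in for a generated sample
display.image(fake, {title = 'fake samples'})
-- then open <server_ip>:<port> in a browser to see it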

Experimental results


Transition experiment: (still having trouble with the transition from 8x8 -> 16x16.)

What does the printed log mean?

(example)
[E:0][T:91][ 91872/202599]    errD(real): 0.2820 | errD(fake): 0.1557 | errG: 0.3838    [Res:   4][Trn(G):0.0%][Trn(D):0.0%][Elp(hr):0.2008]
  • E: epoch / T: ticks (1 tick = 1000 imgs) / errD, errG: losses of the discriminator and generator
  • Res: current resolution of output
  • Trn: transition progress (100% means the stable training phase; less than 100% means the transition phase using the fade-in layer; see the sketch below.)
    • first Trn: transition of the fade-in layer in the generator.
    • second Trn: transition of the fade-in layer in the discriminator.
  • Elp(hr): elapsed time (hours)
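
For intuition, the fade-in blends the upsampled old pathway with the newly grown block, sliding alpha from 0 to 1 over the transition. A conceptual sketch built from stock nn modules (not the repo's actual layer):

require 'nn'

-- takes a table {lowResBranch, highResBranch} and outputs
-- (1 - alpha) * low + alpha * high
local function fadeInBlend(alpha)
   local blend = nn.ParallelTable()
   blend:add(nn.MulConstant(1 - alpha))  -- weight for the old, upsampled pathway
   blend:add(nn.MulConstant(alpha))      -- weight for the new block
   return nn.Sequential():add(blend):add(nn.CAddTable())
end

Since alpha grows every tick during the transition phase, a real implementation has to update these constants (or compute the blend itself) on the fly rather than fixing them at construction time. This is also what the fade-in layer test in the Issues section below checks: the gradient through the blend should equal the alpha-weighted sum of the branch gradients.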

To-Do List (will be implemented soon)

  • Equalized learning rate (weight normalization)
  • Support WGAN-GP loss

Compatibility

  • cuda v8.0
  • Tesla P40 (you may need more than 12GB of GPU memory; if you have less, adjust the batch_table in pggan.lua. A hypothetical sketch follows this list.)
  • python 2.7 / Torch7
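
I don't know the actual contents of batch_table, but a hypothetical shape would map output resolution to batch size, shrinking as resolution (and memory use) grows:

-- hypothetical values only; tune to your GPU memory
local batch_table = {
   [4] = 64, [8] = 64, [16] = 32, [32] = 16,
   [64] = 8, [128] = 4, [256] = 2, [512] = 1, [1024] = 1,
}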

Acknowledgement

Author

MinchulShin, @nashory


progressive-growing-torch's Issues

fade-in layer test

Debug output from testing the fade-in layer: [1] is the gradient sum through the fade-in layer, and [2] is the same value reconstructed from the alpha-weighted branches; the two agree, which confirms the layer's backward pass.

------------
alpha:0
1-alpha:1
[1] grad sum:0.66109210252762
[2] alpha + 1-alpha:0.66109210252762
------------
alpha:0
1-alpha:1
[1] grad sum:-0.11638873815536
[2] alpha + 1-alpha:-0.11638873815536
------------
alpha:0
1-alpha:1
[1] grad sum:0.82711493968964
[2] alpha + 1-alpha:0.82711493968964
------------
alpha:0.00536
1-alpha:0.99464
[1] grad sum:-0.0023111030459404
[2] alpha + 1-alpha:-0.0023110407637432
[E:0][T:8][  8288/202599]    errD(real): 0.0312 | errD(fake): 0.0094 | errG: 0.8557    [Res:   8][Trn(G):0.5%][Trn(D):0.0%][Elp(hr):0.0261]
------------
alpha:0
1-alpha:1
[1] grad sum:0.5822246670723
[2] alpha + 1-alpha:0.5822246670723
------------
alpha:0
1-alpha:1
[1] grad sum:-1.6143760681152
[2] alpha + 1-alpha:-1.6143760681152
------------
alpha:0
1-alpha:1
[1] grad sum:0.58907866477966
[2] alpha + 1-alpha:0.58907866477966
------------
alpha:0.0054
1-alpha:0.9946
[1] grad sum:0.27052643895149
[2] alpha + 1-alpha:0.27052646130323
[E:0][T:8][  8320/202599]    errD(real): 0.0391 | errD(fake): 0.5431 | errG: 0.0795    [Res:   8][Trn(G):0.5%][Trn(D):0.0%][Elp(hr):0.0263]
------------
alpha:0
1-alpha:1
[1] grad sum:1.7794182300568
[2] alpha + 1-alpha:1.7794182300568
------------
alpha:0
1-alpha:1
[1] grad sum:-0.21309423446655
[2] alpha + 1-alpha:-0.21309423446655
------------
alpha:0
1-alpha:1
[1] grad sum:1.1774388551712
[2] alpha + 1-alpha:1.1774388551712
------------
alpha:0.00544
1-alpha:0.99456
[1] grad sum:1.3399093151093
[2] alpha + 1-alpha:1.339909250848
[E:0][T:8][  8352/202599]    errD(real): 0.3095 | errD(fake): 0.0158 | errG: 0.8003    [Res:   8][Trn(G):0.5%][Trn(D):0.0%][Elp(hr):0.0264]
------------
alpha:0
1-alpha:1
[1] grad sum:0.37424358725548
[2] alpha + 1-alpha:0.37424358725548
------------

PyTorch implementation released.

The results were not satisfying, so I have made a new PyTorch version of this code.
Currently, I am working on this in PyTorch.
I think the PyTorch version is much faster and more stable, but I still need to verify that the code works well.

[PyTorch PGGAN]

tested model structures

I found that PGGAN is very sensitive to the network structure, and I think it would be very helpful if already-tested networks were shared.

About training time cost

It's a great implementation, and I can see that it took only 12 minutes to reach 91872/202599. I wonder what your hardware setup is, because it took about 1 hour to reach 1568/596385 on my four Titan Xps. Or is there something wrong with my experiment settings?

Initial model structure (when resolution = 4x4)

For testing:

Generator structure:
nn.Sequential {
  [input -> (1) -> (2) -> output]
  (1): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    (1): nn.SpatialFullConvolution(512 -> 512, 4x4)
    (2): nn.LeakyReLU(0.2)
    (3): nn.SpatialCrossMapLRN
    (4): nn.SpatialFullConvolution(512 -> 512, 3x3, 1,1, 1,1)
    (5): nn.LeakyReLU(0.2)
    (6): nn.SpatialBatchNormalization (4D) (512)
  }
  (2): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> output]
    (1): nn.SpatialFullConvolution(512 -> 3, 1x1)
    (2): nn.LeakyReLU(0.2)
    (3): nn.SpatialCrossMapLRN
  }
}
Discriminator structure:
nn.Sequential {
  [input -> (1) -> (2) -> output]
  (1): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> output]
    (1): nn.SpatialConvolution(3 -> 512, 1x1)
    (2): nn.LeakyReLU(0.2)
    (3): nn.SpatialBatchNormalization (4D) (512)
  }
  (2): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> output]
    (1): nn.MinibatchStatConcat
    (2): nn.SpatialConvolution(513 -> 512, 3x3, 1,1, 1,1)
    (3): nn.LeakyReLU(0.2)
    (4): nn.SpatialCrossMapLRN
    (5): nn.SpatialConvolution(512 -> 512, 4x4)
    (6): nn.LeakyReLU(0.2)
    (7): nn.SpatialCrossMapLRN
    (8): nn.View(512)
    (9): nn.Linear(512 -> 1)
    (10): nn.Sigmoid
  }
}

Bug Fixed: Weight Freezing after the transition

I found that the weight-freezing function WAS NOT WORKING so far, and I managed to track down why.
I fixed the bug and confirmed that the weights now freeze properly. This was actually a critical issue for network stabilization, and I hope the training goes well from now on :-)
Please see below for more details.
(bf717d5)
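
For reference, a common Torch7 idiom for freezing a module (not necessarily what commit bf717d5 does) is to make parameter-gradient accumulation a no-op while leaving updateGradInput alone, so gradients still flow through to the layers below:

-- parameters stop updating, but gradients still propagate through
local function freeze(module)
   module:apply(function(m)
      m.accGradParameters = function() end  -- no-op: gradWeight/gradBias stay untouched
   end)
   module:zeroGradParameters()  -- clear any stale gradients once
end

Note that if the optimizer applies momentum or weight decay directly to the flattened parameters, zero gradients alone may not keep the weights perfectly fixed; this is only a sketch of the idea.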
