sky-generator

About

This script generates new images of skies (including sunsets and sunrises) using generative adversarial networks (GANs), as described in the paper by Goodfellow et al. The images are enhanced with the Laplacian pyramid technique from Denton, Chintala et al., implemented as a single G (generator) as described in the blog post by Anders Boesen Lindbo Larsen and Søren Kaae Sønderby. Most of the code is based on Facebook's eyescream project. It also uses code from other repositories for weight initialization and LeakyReLUs. The images were collected from Flickr.
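For reference, the two-player objective from the Goodfellow et al. paper is shown below. D(x) is the probability D assigns to x being a real training image, and G(z) turns a noise vector z into an image. The README does not state which loss variant train.lua optimizes (many implementations instead train G to maximize log D(G(z))), so read this as background on the technique rather than a description of this repository's code.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```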

Images

All of the following images were generated by a network trained with the command th train.lua --scale=32 --D_L1=0 --D_L2=1e-4 --D_iterations=2 --G_pretrained_dir=NONE. (Except for the training set images, obviously.)

64 generated 32x64 images

64 sky images (32x64 pixels), selected from 512 generated images as the ones with the highest rating according to D.
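Picking the best-rated images amounts to scoring every sample with D and keeping the top k. The following Python 3/NumPy sketch (the repository's own Python scripts target 2.7) assumes D's outputs for all 512 samples are already available as an array of scores; `best_rated` is a hypothetical helper, not part of sample.lua.

```python
import numpy as np


def best_rated(images, d_scores, k=64):
    """Return the k images with the highest D rating, best first.

    images:   array of shape (N, C, H, W), e.g. N = 512 generated samples
    d_scores: array of shape (N,), D's output for each sample
    """
    d_scores = np.asarray(d_scores)
    order = np.argsort(-d_scores)  # indices sorted by descending score
    top = order[:k]
    return images[top], d_scores[top]


# Usage with random stand-in data: 512 RGB samples of 32 x 64 pixels.
rng = np.random.default_rng(0)
samples = rng.random((512, 3, 32, 64))
scores = rng.random(512)
best, best_scores = best_rated(samples, scores, k=64)
print(best.shape)  # (64, 3, 32, 64)
```

Sorting by descending score also gives the "worst images" for free: they are the last k indices of `order`.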

512 generated 32x64 images

512 32x64 sky images generated by the neural net.

Training set

Training set images (for comparison).

Training progress video

Training progress from epoch 1 to 271 as a YouTube video.

Nearest neighbours of generated 32x64 images

16 generated images (left image of each pair) and their nearest neighbours from the training set (right image of each pair). Distance was measured with the L2 norm (torch.dist()). The 16 images shown were the best-rated of 512 generated images according to D, so some similarity to the training set is expected.
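The nearest-neighbour search can be sketched with NumPy as below. It computes the same Euclidean distance as torch.dist() (whose default is the 2-norm) for every generated/training pair at once, using the expansion ‖g - t‖² = ‖g‖² - 2 g·t + ‖t‖², and keeps the closest training image for each sample. `nearest_neighbours` is a hypothetical helper; for a training set of tens of thousands of images you would process it in chunks to bound memory.

```python
import numpy as np


def nearest_neighbours(generated, training):
    """For each generated image, find the closest training image by L2 distance.

    generated: (G, C, H, W), training: (T, C, H, W)
    Returns (indices, distances), each of shape (G,).
    """
    g = generated.reshape(len(generated), -1).astype(np.float64)
    t = training.reshape(len(training), -1).astype(np.float64)
    # Squared distances for all pairs: ||g||^2 - 2 g.t + ||t||^2, shape (G, T)
    sq = (g ** 2).sum(axis=1)[:, None] - 2.0 * g @ t.T + (t ** 2).sum(axis=1)[None, :]
    sq = np.maximum(sq, 0.0)  # clip tiny negative values caused by rounding
    idx = sq.argmin(axis=1)
    return idx, np.sqrt(sq[np.arange(len(g)), idx])


# Usage: two "generated" images that are slightly perturbed training images.
rng = np.random.default_rng(1)
train = rng.random((100, 3, 32, 64))
gen = train[[5, 42]] + rng.normal(scale=1e-3, size=(2, 3, 32, 64))
idx, dist = nearest_neighbours(gen, train)
```

Here `idx` should point back at training images 5 and 42, with small distances.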

Requirements

  • Torch with the following libraries (most of them are probably already installed by default):
    • nn (luarocks install nn)
    • paths (luarocks install paths)
    • image (luarocks install image)
    • optim (luarocks install optim)
    • cutorch (luarocks install cutorch)
    • cunn (luarocks install cunn)
    • cudnn (luarocks install cudnn)
    • dpnn (luarocks install dpnn)
    • display
  • Python 2.7 (only tested with that version)
    • scipy
    • numpy
    • scikit-image
  • CUDA-capable GPU (~3 GB of memory or more) with cuDNN v3

Usage

  • Clone the repository.
  • Change into the dataset/ directory.
  • python download.py - This will download images from Flickr using the keywords sky, sky night, sunset and sunrise. Stop the script manually once it repeatedly reports that it has already downloaded certain images. That happens after downloading page ~150 of the search results, or at about 15k images. Note that the script will take several hours to run, as its time delays are set rather high.
  • python generate_dataset.py - This will take the downloaded images and augment them to 10 times the original number (by rotating, scaling and translating them).
  • ~/.display/run.js & - This starts display.
  • Open http://localhost:8000/ in your browser (plotting interface by display).
  • th train_v.lua --scale=32 --saveFreq=5 - This will train V, a network which shows you a rating for the generated images during GAN training. V works rather badly on these sky images, so there isn't much point in training it for long. Stop after the 5th epoch (i.e. after the first saving to <directory> message).
  • (optional) th pretrain_g.lua --scale=32 - This pretrains G in an autoencoder fashion, which can reduce the number of necessary training epochs a bit. You will have to stop this script manually after a saving to <directory> message.
  • th train.lua --scale=32 --D_L1=0 --D_L2=1e-4 --D_iterations=2 - This trains the GAN. Let it run as long as you want (at least 200 epochs are recommended).
  • th sample.lua - This will generate images from your trained network and save them in samples/. (Random images, best images as rated by D, worst images, training set images.) Add --neighbours to sample the nearest neighbours of generated images (takes a long time). Add --runs=10 to sample 10 times the number of images.
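generate_dataset.py is described only as rotating, scaling and translating each image until the dataset is 10 times larger. Below is a minimal NumPy sketch of such an augmentation, assuming each image is an (H, W, C) array and that "10 times" means the original plus 9 random variants. The helper names, the parameter ranges and the nearest-pixel sampling with edge clamping are all assumptions, not the script's actual settings.

```python
import numpy as np


def random_affine(img, rng, max_angle=10.0, scale_range=(0.9, 1.1), max_shift=0.1):
    """Randomly rotate, scale and translate an (H, W, C) image.

    Hypothetical ranges: rotation in degrees, scale factor, and shift as a
    fraction of height/width. Each output pixel is filled from the nearest
    source pixel (inverse mapping); coordinates outside the image are
    clamped to the border.
    """
    h, w = img.shape[:2]
    angle = np.deg2rad(rng.uniform(-max_angle, max_angle))
    scale = rng.uniform(*scale_range)
    ty, tx = rng.uniform(-max_shift, max_shift, size=2) * (h, w)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w].astype(np.float64)
    # Inverse transform: undo the shift, then the rotation and scaling about the center.
    y0, x0 = ys - cy - ty, xs - cx - tx
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    src_x = (cos_a * x0 + sin_a * y0) / scale + cx
    src_y = (-sin_a * x0 + cos_a * y0) / scale + cy
    ix = np.clip(np.rint(src_x).astype(int), 0, w - 1)
    iy = np.clip(np.rint(src_y).astype(int), 0, h - 1)
    return img[iy, ix]


def augment_dataset(images, factor=10, seed=0):
    """Return each original image followed by (factor - 1) random variants of it."""
    rng = np.random.default_rng(seed)
    out = []
    for img in images:
        out.append(img)
        out.extend(random_affine(img, rng) for _ in range(factor - 1))
    return np.stack(out)


# Usage with random stand-in data (4 RGB images of 32 x 64 pixels).
images = np.random.default_rng(1).random((4, 32, 64, 3))
augmented = augment_dataset(images)
print(augmented.shape)  # (40, 32, 64, 3)
```

Nearest-pixel sampling keeps the sketch dependency-free; a real pipeline would more likely use an interpolating warp (e.g. from scikit-image, which the requirements list).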

Training and architecture

Adam was used as the optimizer. Batch size was 32.
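The README names Adam but not its hyperparameters. For reference, the update rule looks like this in NumPy. The defaults (lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8) are the ones suggested in the Adam paper, not values confirmed for train.lua, and the `Adam` class is a hypothetical stand-in for Torch's optim.adam.

```python
import numpy as np


class Adam:
    """Minimal Adam optimizer updating a list of float NumPy arrays in place."""

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params, self.lr = params, lr
        self.b1, self.b2, self.eps = beta1, beta2, eps
        self.m = [np.zeros_like(p) for p in params]  # first-moment estimates
        self.v = [np.zeros_like(p) for p in params]  # second-moment estimates
        self.t = 0

    def step(self, grads):
        self.t += 1
        for p, g, m, v in zip(self.params, grads, self.m, self.v):
            m *= self.b1
            m += (1 - self.b1) * g
            v *= self.b2
            v += (1 - self.b2) * g * g
            m_hat = m / (1 - self.b1 ** self.t)  # bias correction
            v_hat = v / (1 - self.b2 ** self.t)
            p -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


# Toy usage: minimize sum((w - 3)^2), whose gradient is 2 * (w - 3).
w = np.zeros(4)
opt = Adam([w], lr=0.05)
for _ in range(1000):
    opt.step([2.0 * (w - 3.0)])
```

Because the step is normalized by sqrt(v_hat), the first update has magnitude close to lr regardless of the gradient's scale.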

G:

local model = nn.Sequential()
-- Project the noise vector (noiseDim values) to a 64 x 8 x 16 feature map.
model:add(nn.Linear(noiseDim, 64*8*16))
model:add(nn.View(64, 8, 16))
model:add(nn.PReLU(nil, nil, true))

-- Upsample to 16 x 32; the 5x5 convolution with padding 2 keeps that size (128 channels).
model:add(nn.SpatialUpSamplingNearest(2))
model:add(cudnn.SpatialConvolution(64, 128, 5, 5, 1, 1, (5-1)/2, (5-1)/2))
model:add(nn.SpatialBatchNormalization(128))
model:add(nn.PReLU(nil, nil, true))

-- Upsample to 32 x 64 and convolve down to 64 channels.
model:add(nn.SpatialUpSamplingNearest(2))
model:add(cudnn.SpatialConvolution(128, 64, 5, 5, 1, 1, (5-1)/2, (5-1)/2))
model:add(nn.SpatialBatchNormalization(64))
model:add(nn.PReLU(nil, nil, true))

-- Map to dimensions[1] output channels; the Sigmoid keeps pixel values in (0, 1).
model:add(cudnn.SpatialConvolution(64, dimensions[1], 3, 3, 1, 1, (3-1)/2, (3-1)/2))
model:add(nn.Sigmoid())

where noiseDim was 100 and dimensions[1] was 3.
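To check the layer sizes, the following Python sketch traces the tensor shape through G using the standard output-size formula floor((size + 2*pad - kernel) / stride) + 1. It only does shape arithmetic (PReLU, batch normalization and Sigmoid leave shapes unchanged); `conv2d_out` and `generator_shapes` are hypothetical helpers.

```python
def conv2d_out(size, kernel, stride=1, pad=0):
    """Output length along one spatial axis for a convolution or pooling window."""
    return (size + 2 * pad - kernel) // stride + 1


def generator_shapes(noise_dim=100, channels=3):
    """Trace (C, H, W) shapes through the G layers listed above."""
    shapes = [("noise", (noise_dim,))]
    c, h, w = 64, 8, 16  # nn.Linear(noiseDim, 64*8*16) followed by nn.View(64, 8, 16)
    shapes.append(("Linear + View", (c, h, w)))
    for c_out in (128, 64):
        h, w = 2 * h, 2 * w  # SpatialUpSamplingNearest(2)
        # 5x5 convolution, stride 1, padding (5-1)/2 = 2
        h, w, c = conv2d_out(h, 5, 1, 2), conv2d_out(w, 5, 1, 2), c_out
        shapes.append(("Upsample + Conv 5x5", (c, h, w)))
    # Final 3x3 convolution, stride 1, padding 1, down to the image channels
    h, w = conv2d_out(h, 3, 1, 1), conv2d_out(w, 3, 1, 1)
    shapes.append(("Conv 3x3", (channels, h, w)))
    return shapes


for name, shape in generator_shapes():
    print(f"{name:20s} {shape}")
```

The final shape is 3 x 32 x 64, matching the 32x64 training images used with --scale=32.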

D:

local conv = nn.Sequential()
-- Block 1: dimensions[1] x 32 x 64 -> 128 x 32 x 64, pooled to 128 x 16 x 32.
-- The 3x3 convolutions pass only padW = 1 (padH defaults to padW), so they keep H and W;
-- only the 2x2 average poolings shrink the feature maps.
conv:add(nn.SpatialConvolution(dimensions[1], 128, 3, 3, 1, 1, (3-1)/2))
conv:add(nn.PReLU(nil, nil, true))
conv:add(nn.SpatialDropout(0.2))
conv:add(nn.SpatialAveragePooling(2, 2, 2, 2))

-- Block 2: -> 128 x 8 x 16
conv:add(nn.SpatialConvolution(128, 128, 3, 3, 1, 1, (3-1)/2))
conv:add(nn.PReLU(nil, nil, true))
conv:add(nn.SpatialDropout(0.2))
conv:add(nn.SpatialAveragePooling(2, 2, 2, 2))

-- Block 3: -> 256 x 4 x 8
conv:add(nn.SpatialConvolution(128, 256, 3, 3, 1, 1, (3-1)/2))
conv:add(nn.PReLU(nil, nil, true))
conv:add(nn.SpatialDropout(0.2))
conv:add(nn.SpatialAveragePooling(2, 2, 2, 2))

-- Block 4 (no pooling): stays 256 x 4 x 8
conv:add(nn.SpatialConvolution(256, 256, 3, 3, 1, 1, (3-1)/2))
conv:add(nn.PReLU(nil, nil, true))
conv:add(nn.SpatialDropout())

-- Flatten: each of the three poolings keeps 0.25 of the positions, i.e. 256 * 4 * 8 = 8192 values.
conv:add(nn.View(256 * 0.25 * 0.25 * 0.25 * dimensions[2] * dimensions[3]))
conv:add(nn.Linear(256 * 0.25 * 0.25 * 0.25 * dimensions[2] * dimensions[3], 1024))
conv:add(nn.PReLU(nil, nil, true))
conv:add(nn.Dropout())
conv:add(nn.Linear(1024, 512))
conv:add(nn.PReLU(nil, nil, true))
conv:add(nn.Dropout())
conv:add(nn.Linear(512, 1))
-- The final Sigmoid outputs D's estimate of the probability that the input is a real image.
conv:add(nn.Sigmoid())

where dimensions[1] was 3, dimensions[2] was 32 and dimensions[3] was 64.
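The size passed to nn.View and the first nn.Linear, 256 * 0.25 * 0.25 * 0.25 * dimensions[2] * dimensions[3], follows from the three 2x2 poolings each quartering the number of spatial positions. The sketch below recomputes it layer by layer for a 3 x 32 x 64 input; `conv2d_out` and `discriminator_flat_size` are hypothetical helpers.

```python
def conv2d_out(size, kernel, stride=1, pad=0):
    """Output length along one spatial axis for a convolution or pooling window."""
    return (size + 2 * pad - kernel) // stride + 1


def discriminator_flat_size(channels=3, height=32, width=64):
    """Number of values D flattens before its first fully connected layer."""
    c, h, w = channels, height, width
    # (output channels, followed by 2x2 average pooling?) for the four conv blocks
    for c_out, pooled in ((128, True), (128, True), (256, True), (256, False)):
        # 3x3 convolution, stride 1, padding 1: H and W unchanged
        h, w, c = conv2d_out(h, 3, 1, 1), conv2d_out(w, 3, 1, 1), c_out
        if pooled:  # SpatialAveragePooling(2, 2, 2, 2) halves H and W
            h, w = conv2d_out(h, 2, 2), conv2d_out(w, 2, 2)
    return c * h * w


print(discriminator_flat_size())  # 8192
print(256 * 0.25 ** 3 * 32 * 64)  # 8192.0, the formula used in the Lua code
```

The 0.25 shortcut only matches the layer-by-layer count when height and width are divisible by 8, since each pooling rounds down.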

Contributors

aleju
