GAN

GANs are generative models: they create new data instances that resemble the training data. A common analogy is sneaking into a party with a forged ticket. The generator tries to produce data that looks as if it came from the training data's probability distribution; that is you, trying to reproduce the party's tickets. The discriminator acts as a judge: it decides whether an input came from the generator or from the actual training set. That is the party's security, comparing your fake ticket with a real one to find flaws in your design.

The goal of the generator is to fool the discriminator, so the generator network is trained to maximize the discriminator's final classification error (between real and generated data). The goal of the discriminator is to detect generated data, so the discriminator network is trained to minimize that classification error.
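A minimal sketch of these two goals as binary cross-entropy losses, assuming the discriminator outputs a probability D(x) in (0, 1), for example through a sigmoid. The helper names `bce`, `discriminator_loss` and `generator_loss` are hypothetical, and the generator loss shown is the common non-saturating form (push D(G(z)) toward 1) rather than literally maximizing the discriminator's error; the project's code may differ.

```python
import math

def bce(p, target, eps=1e-12):
    """Binary cross-entropy for one predicted probability p and a 0/1 target."""
    p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def discriminator_loss(d_real, d_fake):
    """Hypothetical helper: D wants real samples -> 1 and generated samples -> 0."""
    real = sum(bce(p, 1) for p in d_real) / len(d_real)
    fake = sum(bce(p, 0) for p in d_fake) / len(d_fake)
    return real + fake

def generator_loss(d_fake):
    """Hypothetical helper (non-saturating): G wants D to label its samples as real."""
    return sum(bce(p, 1) for p in d_fake) / len(d_fake)

# A confident, correct discriminator has a low loss while the generator's loss is high.
d_real, d_fake = [0.9, 0.95], [0.05, 0.1]
print(discriminator_loss(d_real, d_fake), generator_loss(d_fake))
```

Minimizing `generator_loss` instead of minimizing log(1 - D(G(z))) is a standard trick: it gives the generator stronger gradients early in training, when the discriminator easily rejects its samples.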

Loss

(Figure: GAN loss)
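For reference, the standard GAN objective (Goodfellow et al., 2014) is the two-player minimax game below, where D(x) is the discriminator's estimated probability that x is real and G(z) is the generator's output for a noise vector z drawn from p_z. This assumes the loss figure shows this objective or an equivalent binary cross-entropy form; the project's exact loss is not stated in the text.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator maximizes V, which amounts to minimizing its classification error, while the generator minimizes V, matching the two goals described above.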

Dataset

The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples, each a 28×28 grayscale image. I used torchvision's datasets module to download it:

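```python
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])  # assumed: scale pixels to [-1, 1]; the project's actual transform is not shown
train_dataset = datasets.MNIST('mnist/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('mnist/', train=False, download=True, transform=transform)
```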

Model

In this simple GAN, both the discriminator and the generator have a simple architecture composed of fully connected layers. The discriminator uses LeakyReLU activations, which keep a small nonzero gradient for negative inputs so that units do not stop learning (ReLU's gradient is zero there), and the generator uses ReLU activations.
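A small dependency-free sketch of the difference: for a negative input, ReLU outputs 0 and its gradient is 0, so no learning signal flows back through that unit, while LeakyReLU keeps a small slope. The slope of 0.2 is an assumption (a common choice for GAN discriminators; PyTorch's `nn.LeakyReLU` defaults to `negative_slope=0.01`), since the project's value is not stated.

```python
def relu(x):
    return max(0.0, x)

def relu_grad(x):
    # Derivative of ReLU, taking 0 at x == 0 as most frameworks do
    return 1.0 if x > 0 else 0.0

def leaky_relu(x, slope=0.2):
    # slope=0.2 is an assumed value; the project's setting is not stated
    return x if x > 0 else slope * x

def leaky_relu_grad(x, slope=0.2):
    return 1.0 if x > 0 else slope

# Compare outputs and gradients for negative and positive inputs
for x in (-2.0, -0.5, 1.5):
    print(x, relu(x), relu_grad(x), leaky_relu(x), leaky_relu_grad(x))
```

In PyTorch the corresponding layers are `nn.ReLU` and `nn.LeakyReLU`, placed between the `nn.Linear` layers of each network.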

Train

The Trainer class does the main work: it trains the model, plots the training process, and saves the model every n epochs.

I used the Adam optimizer with a learning rate of 0.0002.
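For reference, a single-parameter sketch of the Adam update rule with the stated learning rate of 0.0002. The `beta1`, `beta2` and `eps` values are the defaults of PyTorch's `torch.optim.Adam` (0.9, 0.999, 1e-8); the project may use different betas (many GAN setups use beta1 = 0.5), so treat them as assumptions. The function name `adam_step` is hypothetical.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter at step t (starting at 1); returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# On the first step the parameter moves by about lr, whatever the gradient's scale.
p, m, v = 1.0, 0.0, 0.0
p, m, v = adam_step(p, grad=5.0, m=m, v=v, t=1)
print(p)
```

In PyTorch the equivalent is `torch.optim.Adam(model.parameters(), lr=0.0002)`; GAN implementations typically create one such optimizer for each of the two networks.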

Each generator training step happens in the train_generator function, each discriminator training step in train_descriminator, and the whole training process in the train function.
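The alternating structure can be sketched as follows. This is a hypothetical skeleton, not the project's Trainer: the step functions are passed in as callables named after `train_generator` and `train_descriminator`, `save_every` stands in for the "save the model every n epochs" setting, and the common ordering of one discriminator step followed by one generator step per batch is assumed.

```python
from typing import Callable, Iterable, List, Tuple

def train(
    batches: Callable[[], Iterable[object]],
    train_descriminator: Callable[[object], float],
    train_generator: Callable[[object], float],
    epochs: int,
    save_every: int,
    save: Callable[[int], None],
) -> List[Tuple[float, float]]:
    """Hypothetical training loop; returns the mean (D loss, G loss) for each epoch."""
    history = []
    for epoch in range(1, epochs + 1):
        d_total, g_total, n = 0.0, 0.0, 0
        for batch in batches():
            d_total += train_descriminator(batch)  # update D on real and generated data
            g_total += train_generator(batch)      # then update G against the updated D
            n += 1
        history.append((d_total / max(n, 1), g_total / max(n, 1)))
        if epoch % save_every == 0:
            save(epoch)  # e.g. write weights and plots to save_dir
    return history

# Example with dummy step functions that return constant losses
saved = []
history = train(lambda: range(3), lambda b: 1.0, lambda b: 2.0,
                epochs=5, save_every=2, save=saved.append)
print(saved, history[0])
```

Passing the steps in as callables keeps the loop testable without a GPU; in the real trainer they would run forward and backward passes and call each optimizer's `step()`.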

Some Configurations

  • Set the number of epochs with EPOCHS and the batch size with BATCH_SIZE.
  • Set the device to train the model on with device (by default, CUDA is used if it is available).
  • Set the verbosity level to choose what is printed: 0 prints nothing, 1 prints the model architecture, 2 prints the optimizer, and 3 prints the model parameter sizes.
  • On each training run, the model weights and, if save_plots == True, the training plots are saved in save_dir.
  • save_dir also contains a configs file with information about the run.
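A minimal sketch of these settings and of writing the run's configs file. The field names mirror the options above, but the `Config` class, the default values, the file name `configs.json` and the JSON format are all hypothetical. The project selects CUDA when available (in PyTorch, typically `"cuda" if torch.cuda.is_available() else "cpu"`); here `device` is a plain string to keep the sketch dependency-free.

```python
import json
import os
from dataclasses import asdict, dataclass

@dataclass
class Config:
    # Hypothetical defaults; the project's actual values are not stated.
    EPOCHS: int = 100
    BATCH_SIZE: int = 64
    device: str = "cpu"      # the project defaults to CUDA when it is available
    verbose: int = 0         # 0 nothing, 1 architecture, 2 optimizer, 3 parameter sizes
    save_plots: bool = True
    save_dir: str = "runs/gan"

    def __post_init__(self):
        if self.verbose not in (0, 1, 2, 3):
            raise ValueError("verbose must be 0, 1, 2 or 3")

def save_configs(cfg: Config) -> str:
    """Write the run configuration to save_dir/configs.json and return the file path."""
    os.makedirs(cfg.save_dir, exist_ok=True)
    path = os.path.join(cfg.save_dir, "configs.json")
    with open(path, "w") as f:
        json.dump(asdict(cfg), f, indent=2)
    return path
```

Storing the configuration next to the weights and plots makes each run in save_dir self-describing.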

Results

Trained for 100 epochs:

(Figure: epoch-170-loss-plot, GAN)

Trained for 160 epochs:

(Figure: epoch-160-loss-plot, GAN-v1)

Contributors

sobhanshukueian