keras-generative's Introduction

Keras VAEs and GANs

Keras implementation of various deep generative networks such as VAE and GAN.

Models

Standard models

  • Variational autoencoder (VAE) [Kingma et al. 2013]
  • Generative adversarial network (GAN or DCGAN) [Goodfellow et al. 2014]
  • Improved GAN [Salimans et al. 2016]
  • Energy-based GAN (EBGAN) [Zhao et al. 2016]
  • Adversarially learned inference (ALI) [Dumoulin et al. 2017]

Conditional models

  • Conditional variational autoencoder [Kingma et al. 2014]
  • CVAE-GAN [Bao et al. 2017]

Usage

Prepare datasets

First, download img_align_celeba.zip and list_attr_celeba.txt from the CelebA webpage. Then, place these files in the datasets directory and run create_database.py from inside that directory.
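If it helps, the steps above amount to something like the following sketch (the file names come from the instructions above; running create_database.py from inside datasets is an assumption):

import os
import subprocess

DATASET_DIR = 'datasets'
REQUIRED = ['img_align_celeba.zip', 'list_attr_celeba.txt']

# Both files must be downloaded manually from the CelebA webpage first.
missing = [f for f in REQUIRED if not os.path.exists(os.path.join(DATASET_DIR, f))]
if missing:
    raise FileNotFoundError('Download from the CelebA webpage first: %s' % missing)

# Run the preprocessing script from inside the datasets directory.
subprocess.check_call(['python', 'create_database.py'], cwd=DATASET_DIR)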

Training

# Standard models
python train.py --model=dcgan --epoch=200 --batchsize=100 --output=output

# Conditional models
python train_conditional.py --model=cvaegan --epoch=200 --batchsize=100 --output=output

References

  • Kingma et al., "Auto-Encoding Variational Bayes", arXiv preprint 2013.
  • Goodfellow et al., "Generative Adversarial Nets", NIPS 2014.
  • Salimans et al., "Improved Techniques for Training GANs", arXiv preprint 2016.
  • Zhao et al., "Energy-based Generative Adversarial Network", arXiv preprint 2016.
  • Dumoulin et al., "Adversarially Learned Inference", ICLR 2017.
  • Kingma et al., "Semi-Supervised Learning with Deep Generative Models", NIPS 2014.
  • Bao et al., "CVAE-GAN: Fine-Grained Image Generation through Asymmetric Training", arXiv preprint 2017.

keras-generative's People

Contributors: tatsy
keras-generative's Issues

cvaegan diverges when the dataset is cifar10

Hi, your "keras-generative" implementation is really fascinating.
When training cvaegan.py on cifar10 with your code and hyperparameter settings, the generator diverges and the generated images are noise-like. I haven't tried the CelebA dataset.
Have you generated realistic images with cvaegan.py? Are there any settings I should pay attention to?
I look forward to your reply. Thanks in advance.

About VAELossLayer

Hello,

My question may seem a little simple, but I couldn't figure out some parts. I'd appreciate it if you could give me an explanation.

You define vae_loss as below, using a custom layer class. As far as I understand, this layer simply returns the x_true tensor:
vae_loss = VAELossLayer()([x_true, x_pred, z_avg, z_log_var])

Regarding the loss, shouldn't we use lossfun or add_loss instead of zero_loss? I don't understand how the model can train with a zero_loss tensor:
self.vae_trainer = Model(inputs=[x_true], outputs=[vae_loss])
self.vae_trainer.compile(loss=[zero_loss]...

And this is the class structure. It seems that lossfun and add_loss are not called anywhere else in the code.

class VAELossLayer(Layer):
    name = 'vae_loss_layer'

    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(VAELossLayer, self).__init__(**kwargs)

    def lossfun(self, x_true, x_pred, z_avg, z_log_var):
        rec_loss = K.mean(K.square(x_true - x_pred))
        kl_loss = K.mean(-0.5 * K.sum(1.0 + z_log_var - K.square(z_avg) - K.exp(z_log_var), axis=-1))
        return rec_loss + kl_loss

    def call(self, inputs):
        x_true = inputs[0]
        x_pred = inputs[1]
        z_avg = inputs[2]
        z_log_var = inputs[3]
        loss = self.lossfun(x_true, x_pred, z_avg, z_log_var)
        self.add_loss(loss, inputs=inputs)

        return x_true
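For reference, lossfun above is the usual negative ELBO of a VAE. In LaTeX, writing \mu for z_avg and \log\sigma^2 for z_log_var (both terms are additionally averaged over the batch in the code):

\mathcal{L} = \underbrace{\mathbb{E}\left[(x_{\mathrm{true}} - x_{\mathrm{pred}})^2\right]}_{\texttt{rec\_loss}}
            \underbrace{-\,\frac{1}{2} \sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)}_{\texttt{kl\_loss}}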

Zero_loss function for EBGAN

Can you please explain how exactly the zero_loss function works?

def zero_loss(y_true, y_pred):
    return K.zeros_like(y_true)

Isn't the loss always zero if we use this function?
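For context: Keras minimizes the sum of the compile-time losses and any losses registered with add_loss inside a layer. zero_loss zeroes out the compile-time term, so the gradients come entirely from add_loss. A minimal sketch of this pattern (illustrative, not this repository's code):

import keras.backend as K
from keras.layers import Input, Dense, Layer
from keras.models import Model

def zero_loss(y_true, y_pred):
    return K.zeros_like(y_true)  # contributes exactly 0 to the total loss

class AddLossDemo(Layer):
    def call(self, x):
        self.add_loss(K.mean(K.square(x)))  # this term IS minimized
        return x

inp = Input(shape=(4,))
out = AddLossDemo()(Dense(4)(inp))
model = Model(inp, out)
model.compile(optimizer='adam', loss=[zero_loss])
# Total loss == mean(square(dense output)); the zero_loss term adds nothing.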

GPU requirements

What versions of CUDA and cuDNN are required to run this code?
I am able to run it with cuDNN 6 and CUDA 7, but not with cuDNN 8 and CUDA 9.

What is ConditionalDataset()?

The MNIST data loader requires ConditionalDataset, which is not in the same folder. I am wondering what it is.

from .datasets import ConditionalDataset

def load_data():
    (x_train, y_train), _ = keras.datasets.mnist.load_data()

    x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0)
    x_train = (x_train[:, :, :, np.newaxis] / 255.0).astype('float32')
    y_train = keras.utils.to_categorical(y_train)
    y_train = y_train.astype('float32')

    datasets = ConditionalDataset()  # <-- the class in question
    datasets.images = x_train
    datasets.attrs = y_train
    datasets.attr_names = [str(i) for i in range(10)]

    return datasets
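Judging from how load_data() uses it, ConditionalDataset (defined in the repository's datasets module) is most likely a plain container. A guessed sketch, not the actual class:

class ConditionalDataset(object):
    def __init__(self):
        self.images = None      # float32 image array, shape (N, H, W, C)
        self.attrs = None       # one-hot attribute/label array, shape (N, num_attrs)
        self.attr_names = None  # list of attribute names, e.g. ['0', ..., '9'] for MNIST

    def __len__(self):
        return len(self.images)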

Reconstruct exact input

Hi,
Thank you for your work.
Currently, I am working with CVAE-GAN to reconstruct faces with emotion. I want to reconstruct the input exactly, as a normal autoencoder does, while keeping the discriminator at the same time. My idea is to change the discriminator's binary_crossentropy loss (here) to mse.
Is that right?

Thank you.
tsly

How can I generate new images through the hdf5?

@tatsy
Hi! How can I use the weights in dec_trainer.hdf5?
I added a function in cvaegan.py:

def make_my_image(self, samples):
    self.f_dec.load_weights('/home/liuweihuang/keras-generative/output/cvaegan/weights/epoch_00100/dec_trainer.hdf5')
    self.save_my_images(samples, '/home/liuweihuang/keras-generative/myImg.png')

And it raises:
ValueError: You are trying to load a weight file containing 4 layers into a model with 11 layers.
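That error usually means the .hdf5 file was saved from a different (sub)model than the one it is being loaded into. One common workaround in Keras, sketched here as an assumption (untested against this repository), is to load only the layers whose names match:

# Load weights by layer name instead of requiring identical topology.
self.f_dec.load_weights(
    '/home/liuweihuang/keras-generative/output/cvaegan/weights/epoch_00100/dec_trainer.hdf5',
    by_name=True)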

About the loss in CVAE-GAN

Hi, your implementation is really fascinating!
I have some questions about the loss function part of CVAE-GAN. In the paper, L_GD and L_GC have a loss weight of 1e-3, but I could not find it in your implementation. Does this affect the training of the generator? I have trained for 30 epochs on the celebB dataset, and the generator accuracy is always 0.
Looking forward to your reply.
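For reference, a per-term weight like the paper's 1e-3 can be expressed with Keras's loss_weights argument at compile time. A sketch with illustrative names (not this repository's code):

# 'recon', 'gd', 'gc' are hypothetical output names standing in for L_G, L_GD, L_GC.
model.compile(optimizer='adam',
              loss={'recon': 'mse', 'gd': 'mse', 'gc': 'mse'},
              loss_weights={'recon': 1.0, 'gd': 1e-3, 'gc': 1e-3})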

Testing cvaegan

Hi, I have trained cvaegan on my dataset, but now I want to test it as an image translation model on a single image, and I also want to use the classifier function separately on one input image. I tried doing this, but I got the following error for the classification model.

ValueError: You are trying to load a weight file containing 1 layers into a model with 10 layers.

g_acc close to 0 and d_acc close to 1

Hello,
Thanks for your work, it's fascinating. But I have a problem when training CVAE-GAN.
When does the model converge? When g_acc is close to 1 and d_acc is close to 0?
And on MNIST, I found that g_acc always drops to 0 very quickly, and it seems that the discriminator dominates the generator. How can I solve this problem?
Thanks!
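A discriminator that wins too quickly is a common GAN failure mode. One widely used stabilizer, one-sided label smoothing from Salimans et al. 2016 (cited in the references above), softens the real-label targets. A sketch with illustrative names, not this repository's code:

import numpy as np

batchsize = 100
# Train the discriminator on smoothed real targets (0.9) and hard fake targets (0).
real_labels = np.full((batchsize, 1), 0.9, dtype='float32')
fake_labels = np.zeros((batchsize, 1), dtype='float32')
# d_model.train_on_batch(real_images, real_labels)  # d_model etc. are placeholders
# d_model.train_on_batch(fake_images, fake_labels)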
