
Comments (5)

eriklindernoren commented on July 27, 2024

Oh, I see what you mean. Hm. I haven't tried this but I think you can do:

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(self.label_dim,), dtype='int32')

        label_embedding = Embedding(self.num_classes, self.latent_dim, input_length=self.label_dim)(label) # output shape (batch_size, self.label_dim, self.latent_dim)

        model_input = multiply([noise, label_embedding]) # output shape (batch_size, self.label_dim, self.latent_dim)
        model_input = Flatten()(model_input) # output shape (batch_size, self.latent_dim * self.label_dim)

Then you would have to change model.add(Dense(256, input_dim=self.latent_dim)) to model.add(Dense(256, input_dim=self.latent_dim * self.label_dim)) on line 60.

Instead of flattening the tensor of shape (batch_size, self.label_dim, self.latent_dim), it should also be possible to multiply along dimension 1 to get a tensor of shape (batch_size, self.latent_dim):

        ...
        model_input = dot(model_input, axes=1) # output shape (batch_size, self.latent_dim)

I would try both ways and see which yields the best results.
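The shapes in both variants can be checked without Keras. Below is a minimal NumPy sketch, not the repository's code: the sizes are hypothetical, an array lookup stands in for the Embedding layer, and the reduction over axis 1 is written as a sum, which is an assumption since the exact operation isn't pinned down above.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, label_dim, latent_dim, num_classes = 4, 3, 100, 10

rng = np.random.default_rng(0)
noise = rng.normal(size=(batch_size, latent_dim))
labels = rng.integers(0, num_classes, size=(batch_size, label_dim))

# Stand-in for the Embedding layer: one latent_dim-sized row per class.
embedding_table = rng.normal(size=(num_classes, latent_dim))
label_embedding = embedding_table[labels]  # (batch_size, label_dim, latent_dim)

# Broadcast the noise over the label axis before the element-wise product.
product = noise[:, None, :] * label_embedding  # (batch_size, label_dim, latent_dim)

# Variant 1: flatten, so the first Dense layer sees latent_dim * label_dim inputs.
flattened = product.reshape(batch_size, -1)

# Variant 2: reduce over axis 1 (a sum here, by assumption), keeping latent_dim inputs.
reduced = product.sum(axis=1)

print(flattened.shape, reduced.shape)
```

With variant 1 the generator's first layer grows with label_dim; with variant 2 its input size stays at latent_dim regardless of how many labels are supplied.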

from keras-gan.

eriklindernoren commented on July 27, 2024

Sure, in the current implementation of cGAN the generator and discriminator expect a label input of shape (x,), where x is an integer value between 0 and self.num_classes - 1. So for MNIST this integer value will be between 0 and 9. But this can be adapted for any classification task.


dotjrt commented on July 27, 2024

The problem is that the label shape is hardcoded as label = Input(shape=(1,), dtype='int32'), so trying to use any shape greater than 1 will cause an error during the model_input = multiply([noise, label_embedding]) operation. The error will be something like ValueError: Operands could not be broadcast together with shapes (x,) (y,).
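The mismatch can be reproduced with plain NumPy arrays. This is only a sketch: it assumes the original code flattens the embedding output before the multiply, and the sizes are hypothetical.

```python
import numpy as np

# Hypothetical sizes; label_dim > 1 is what triggers the problem.
batch_size, latent_dim, num_classes = 4, 100, 10
label_dim = 3

noise = np.zeros((batch_size, latent_dim))
embedding_table = np.ones((num_classes, latent_dim))  # stand-in for Embedding
labels = np.zeros((batch_size, label_dim), dtype=np.int32)

# Assumed original behaviour: embed, then flatten before multiplying.
flat_embedding = embedding_table[labels].reshape(batch_size, -1)  # (4, 300)

try:
    noise * flat_embedding  # (4, 100) * (4, 300): trailing axes differ
    broadcast_error = None
except ValueError as exc:
    broadcast_error = str(exc)

print(broadcast_error)
```

With label_dim = 1 the flattened embedding has shape (batch_size, latent_dim) and the product goes through, which is why the single-label case works.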


dotjrt commented on July 27, 2024

Ah yes, that appears to work. For the discriminator network it's the same code: just swap self.latent_dim with np.prod(self.img_shape), rework the training loop as necessary, and it's good to go.


mrgloom commented on July 27, 2024

Is it common to use Embedding for CGAN? Why not just append the label to the flattened image?

Related:
#143
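For reference, the alternative asked about could look roughly like this in NumPy. It is a sketch of the idea only, not what the repository does; the one-hot encoding and the MNIST-sized img_shape are assumptions.

```python
import numpy as np

num_classes = 10
img_shape = (28, 28, 1)  # MNIST-sized images, as an assumption
batch_size = 4

images = np.zeros((batch_size,) + img_shape)
labels = np.array([3, 1, 4, 1])  # one integer class per image

flat_images = images.reshape(batch_size, -1)  # (4, 784)
one_hot = np.eye(num_classes)[labels]         # (4, 10)

# Append the one-hot label to each flattened image.
conditioned = np.concatenate([flat_images, one_hot], axis=1)
print(conditioned.shape)
```

Here the discriminator's first layer would take np.prod(img_shape) + num_classes inputs instead of np.prod(img_shape).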

