Comments (5)
Oh, I see what you mean. Hm. I haven't tried this, but I think you can do:
```python
from keras.layers import Input, Embedding, Flatten, multiply

noise = Input(shape=(self.latent_dim,))
label = Input(shape=(self.label_dim,), dtype='int32')
label_embedding = Embedding(self.num_classes, self.latent_dim, input_length=self.label_dim)(label)  # output shape (batch_size, self.label_dim, self.latent_dim)
model_input = multiply([noise, label_embedding])  # output shape (batch_size, self.label_dim, self.latent_dim)
model_input = Flatten()(model_input)  # output shape (batch_size, self.latent_dim * self.label_dim)
```
Then you would have to change `model.add(Dense(256, input_dim=self.latent_dim))` to `model.add(Dense(256, input_dim=self.latent_dim * self.label_dim))` on line 60.
Instead of flattening the tensor of shape `(batch_size, self.label_dim, self.latent_dim)`, it should also be possible to multiply along dimension 1 to get a tensor of shape `(batch_size, self.latent_dim)`:
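```python
from keras import backend as K
from keras.layers import Lambda

...
# Product over dimension 1 collapses the label axis; output shape (batch_size, self.latent_dim)
model_input = Lambda(lambda x: K.prod(x, axis=1))(model_input)
```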
I would try both ways and see which yields better results.
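A minimal NumPy sketch of the two options above, assuming an embedding lookup behaves like indexing a `(num_classes, latent_dim)` weight table, with made-up sizes and random weights standing in for a trained model. It only checks the shapes and the broadcast of `noise` across the label axis; it is not the Keras model itself.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, latent_dim, label_dim, num_classes = 4, 100, 3, 10
rng = np.random.default_rng(0)

# Stand-in for the Embedding layer's weights: one latent_dim vector per class.
embedding_table = rng.normal(size=(num_classes, latent_dim))

noise = rng.normal(size=(batch_size, latent_dim))                    # (batch, latent_dim)
labels = rng.integers(0, num_classes, size=(batch_size, label_dim))  # (batch, label_dim)

# Embedding lookup by indexing: (batch, label_dim, latent_dim)
label_embedding = embedding_table[labels]

# Element-wise multiply; noise gets a length-1 label axis and is broadcast over it.
combined = noise[:, np.newaxis, :] * label_embedding                 # (batch, label_dim, latent_dim)

# Option 1: flatten, so the first Dense layer sees latent_dim * label_dim features.
flat_input = combined.reshape(batch_size, -1)                        # (batch, latent_dim * label_dim)

# Option 2: multiply along dimension 1, collapsing the label axis.
prod_input = np.prod(combined, axis=1)                               # (batch, latent_dim)

print(flat_input.shape, prod_input.shape)  # (4, 300) (4, 100)
```

One property of the second option worth keeping in mind: every row being multiplied contains `noise`, so the result equals `noise ** label_dim` times the product of the label embeddings, element-wise, which can shift the input's scale as `label_dim` grows.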
from keras-gan.
Sure. In the current implementation of cGAN, the generator and discriminator expect a label input of the form `(x,)`, i.e. shape `(1,)`, where `x` is an integer between `0` and `self.num_classes - 1`. So for MNIST this integer value will be between 0 and 9, but this can be adapted for any classification task.
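To make the expected format concrete, here is a small sketch that turns integer class labels into a `(batch_size, 1)` int32 array and checks their range. `prepare_labels` is a hypothetical helper, not part of keras-gan.

```python
import numpy as np

def prepare_labels(y, num_classes):
    """Hypothetical helper: shape integer class labels as (batch_size, 1) int32
    and check that every value lies in [0, num_classes - 1]."""
    y = np.asarray(y, dtype=np.int32).reshape(-1, 1)
    if y.min() < 0 or y.max() > num_classes - 1:
        raise ValueError("labels must lie in [0, num_classes - 1]")
    return y

# MNIST-style digit labels, so num_classes = 10 and values run from 0 to 9.
labels = prepare_labels([3, 0, 9, 7], num_classes=10)
print(labels.shape, labels.dtype)  # (4, 1) int32
```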
The problem is that the label shape is hardcoded as `label = Input(shape=(1,), dtype='int32')`, so trying to use any shape larger than 1 will cause an error during the `model_input = multiply([noise, label_embedding])` operation. The error will be something like `ValueError: Operands could not be broadcast together with shapes (x,) (y,)`.
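The error can be reproduced with NumPy alone. This sketch assumes (not stated above) that the label embedding is flattened to a vector before the multiply, so with `label_dim = 3` it has three times as many entries per sample as `noise`; the sizes and random weights are made up.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, latent_dim, label_dim, num_classes = 4, 100, 3, 10
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(num_classes, latent_dim))  # stand-in for Embedding weights

noise = rng.normal(size=(batch_size, latent_dim))                    # (4, 100)
labels = rng.integers(0, num_classes, size=(batch_size, label_dim))  # (4, 3)

# Assumption: the embedding output is flattened before the multiply.
flat_embedding = embedding_table[labels].reshape(batch_size, -1)     # (4, 300)

try:
    noise * flat_embedding  # (4, 100) vs (4, 300): the last axes do not match
    error = None
except ValueError as exc:
    error = str(exc)
print(error)

# With a single label per sample the flattened embedding is (4, 100) and the multiply works.
single_embedding = embedding_table[labels[:, :1]].reshape(batch_size, -1)
single_input = noise * single_embedding
print(single_input.shape)  # (4, 100)
```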
Ah yes, that appears to work. For the discriminator network it's the same code; just swap `self.latent_dim` with `np.prod(self.img_shape)`, rework the training loop as necessary, and it's good to go.
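A NumPy sketch of the discriminator side, assuming an embedding lookup behaves like indexing a weight table, MNIST-shaped images, and made-up sizes: the flattened image takes the place of `noise`, and `np.prod(img_shape)` takes the place of `self.latent_dim`. `discriminator_input` and `embedding_table` are hypothetical names. The last lines show one way label sampling in a training loop might change, drawing `label_dim` labels per image instead of one.

```python
import numpy as np

# Hypothetical sizes; MNIST-shaped images are assumed for illustration.
batch_size, label_dim, num_classes = 4, 3, 10
img_shape = (28, 28, 1)
flat_dim = int(np.prod(img_shape))  # takes the place of self.latent_dim: 784
rng = np.random.default_rng(0)

# Stand-in for the discriminator's Embedding weights (num_classes x flat_dim).
embedding_table = rng.normal(size=(num_classes, flat_dim))

def discriminator_input(imgs, labels):
    """Hypothetical helper: combine flattened images with per-label embeddings."""
    flat_img = imgs.reshape(len(imgs), -1)                     # (batch, 784)
    label_embedding = embedding_table[labels]                  # (batch, label_dim, 784)
    combined = flat_img[:, np.newaxis, :] * label_embedding    # (batch, label_dim, 784)
    return combined.reshape(len(imgs), -1)                     # (batch, label_dim * 784)

imgs = rng.normal(size=(batch_size,) + img_shape)

# In the training loop, draw label_dim labels per image instead of a single one.
sampled_labels = rng.integers(0, num_classes, size=(batch_size, label_dim))
d_input = discriminator_input(imgs, sampled_labels)
print(d_input.shape)  # (4, 2352)
```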
Is it common to use `Embedding` for CGAN? Why not just append the label to the flattened image?
Related: #143
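For comparison with the embedding approach, here is a minimal NumPy sketch of the alternative the question describes: one-hot encoding the label and appending it to the flattened image. MNIST-shaped images, the sizes, and the per-label class counts in the multi-label variant are assumptions for illustration only.

```python
import numpy as np

# Hypothetical sizes; MNIST-shaped images are assumed for illustration.
batch_size, num_classes = 4, 10
img_shape = (28, 28, 1)
rng = np.random.default_rng(0)

imgs = rng.normal(size=(batch_size,) + img_shape)
labels = rng.integers(0, num_classes, size=batch_size)

flat_img = imgs.reshape(batch_size, -1)       # (4, 784)
one_hot = np.eye(num_classes)[labels]         # (4, 10), one row of np.eye per label
concat_input = np.concatenate([flat_img, one_hot], axis=1)
print(concat_input.shape)  # (4, 794)

# Multi-label variant: each label column may have its own number of classes.
classes_per_label = [10, 3, 5]  # hypothetical class counts
multi_labels = np.stack(
    [rng.integers(0, n, size=batch_size) for n in classes_per_label], axis=1
)                                             # (4, 3)
one_hots = [np.eye(n)[multi_labels[:, i]] for i, n in enumerate(classes_per_label)]
multi_input = np.concatenate([flat_img] + one_hots, axis=1)
print(multi_input.shape)  # (4, 802)
```

The design difference: concatenation keeps image and label features side by side and leaves it to the following layers to learn how to mix them, while the embedding-multiply version scales every input feature by a label-dependent value. The sketch only shows input construction; it says nothing about which trains better.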