Comments (4)
Since self.discriminator.trainable = False is set after the discriminator is compiled, it does not affect the training of the discriminator. However, since it is set before the combined model is compiled, the discriminator's layers are frozen when the combined model is trained. You can read more about this here:
https://keras.io/getting-started/faq/ - Under 'How can I "freeze" Keras layers?'
If you print the weights of the discriminator during training (e.g. by adding the two lines below), you will also see that they change on each training iteration.

for layer in self.discriminator.layers:
    print(layer.get_weights())
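The semantics described above can be illustrated with a toy sketch in plain Python (ToyLayer and ToyModel are invented stand-ins for illustration, not the real Keras classes): compile() takes a snapshot of which layers are trainable at that moment, so flipping the trainable flag afterwards does not change what an already-compiled model updates.

```python
class ToyLayer:
    """Stand-in for a Keras layer: one weight and a trainable flag."""
    def __init__(self, name):
        self.name = name
        self.trainable = True
        self.weight = 0.0

class ToyModel:
    """Mimics the Keras behaviour: compile() snapshots the trainable flags."""
    def __init__(self, layers):
        self.layers = layers
        self._trainable_at_compile = None

    def compile(self):
        # Snapshot which layers are trainable right now.
        self._trainable_at_compile = [l for l in self.layers if l.trainable]

    def train_step(self):
        # Only layers captured at compile time get updated.
        for layer in self._trainable_at_compile:
            layer.weight += 1.0

disc_layers = [ToyLayer("d1"), ToyLayer("d2")]
discriminator = ToyModel(disc_layers)
discriminator.compile()          # compiled while still trainable

for layer in disc_layers:
    layer.trainable = False      # frozen AFTER compiling

combined = ToyModel(disc_layers)
combined.compile()               # compiled while frozen

discriminator.train_step()       # weights still update: snapshot was trainable
combined.train_step()            # no update: snapshot was empty
print([l.weight for l in disc_layers])  # -> [1.0, 1.0]
```

This mirrors the GAN training loop: the same discriminator layers are updated when the discriminator model trains, but stay fixed when the combined model trains.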
from keras-gan.
That is strange, because even when I print the model summaries after both the discriminator and the combined model are compiled, I see that the discriminator's layers are trainable while the discriminator's layers in the combined model are frozen. I don't really see why your approach does not work. The difference I see when I look at your code is that you don't define the generator and the discriminator the same way I do. In your _generator method:
def _generator(inp_size):
    """
    construct the generator model
    :param inp_size: size of the input noise vector
    :return: generator model
    """
    # define the generator model
    generator_model = Sequential()
    # dense block for bringing spatial integrity
    generator_model.add(Dense(128, activation='relu', input_shape=(inp_size,)))
    ....
    generator_model.add(Conv2DTranspose(filters=3, kernel_size=(3, 3),
                                        padding='same', activation='sigmoid'))
    # final output of the model is 32 x 32 x 3 (cifar-10 image size)
    return generator_model
You do not define inputs and outputs of your model. When you then define the combined model you refer to the inputs of the generator, which are not defined: comb = Model(inputs=gen.inputs, outputs=dis(gen.outputs)). You can compare this to my implementation, where I define the input (noise = Input(shape=noise_shape)) and output (img = model(noise)) of the model (Model(noise, img)) before I return and compile it:
def build_generator(self):
    noise_shape = (100,)

    model = Sequential()
    model.add(Dense(256, input_shape=noise_shape))
    ...
    model.add(Reshape(self.img_shape))
    model.summary()

    noise = Input(shape=noise_shape)
    img = model(noise)

    return Model(noise, img)
You could try it this way and see if it fixes your problem. I have read about similar issues regarding model.trainable before. You can read more here:
keras-team/keras#4674
@eriklindernoren Thank you for replying. As you pointed out, the FAQ says that changing the trainable property after compiling a model does not affect that model.
But here is the output of my discriminator.summary():
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 32, 32, 16) 448
_________________________________________________________________
conv2d_2 (Conv2D) (None, 32, 32, 16) 2320
_________________________________________________________________
conv2d_3 (Conv2D) (None, 32, 32, 16) 2320
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 16, 16, 16) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 16, 16, 32) 12832
_________________________________________________________________
conv2d_5 (Conv2D) (None, 16, 16, 32) 25632
_________________________________________________________________
conv2d_6 (Conv2D) (None, 16, 16, 32) 25632
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 8, 8, 32) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 8, 8, 64) 51264
_________________________________________________________________
conv2d_8 (Conv2D) (None, 8, 8, 64) 102464
_________________________________________________________________
conv2d_9 (Conv2D) (None, 8, 8, 64) 102464
_________________________________________________________________
flatten_1 (Flatten) (None, 4096) 0
_________________________________________________________________
batch_normalization_1 (Batch (None, 4096) 16384
_________________________________________________________________
dense_7 (Dense) (None, 1024) 4195328
_________________________________________________________________
batch_normalization_2 (Batch (None, 1024) 4096
_________________________________________________________________
dense_8 (Dense) (None, 256) 262400
_________________________________________________________________
batch_normalization_3 (Batch (None, 256) 1024
_________________________________________________________________
dense_9 (Dense) (None, 32) 8224
_________________________________________________________________
batch_normalization_4 (Batch (None, 32) 128
_________________________________________________________________
dense_10 (Dense) (None, 1) 33
=================================================================
Total params: 4,812,993
Trainable params: 0
Non-trainable params: 4,812,993
_________________________________________________________________
And here is the output of combined.summary():
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1_input (InputLayer) (None, 16) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 2176
_________________________________________________________________
dense_2 (Dense) (None, 256) 33024
_________________________________________________________________
dense_3 (Dense) (None, 512) 131584
_________________________________________________________________
dense_4 (Dense) (None, 1024) 525312
_________________________________________________________________
dense_5 (Dense) (None, 2048) 2099200
_________________________________________________________________
dense_6 (Dense) (None, 4096) 8392704
_________________________________________________________________
reshape_1 (Reshape) (None, 8, 8, 64) 0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 32) 100384
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 16, 16, 32) 50208
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 16, 16, 32) 50208
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 32, 32, 16) 12816
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 32, 32, 16) 6416
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 32, 32, 16) 6416
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 32, 32, 8) 1160
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 32, 32, 8) 584
_________________________________________________________________
conv2d_transpose_9 (Conv2DTr (None, 32, 32, 3) 219
_________________________________________________________________
sequential_2 (Sequential) (None, 1) 4812993
=================================================================
Total params: 16,225,404
Trainable params: 11,412,411
Non-trainable params: 4,812,993
_________________________________________________________________
As is apparent, the combined model works fine, but the discriminator has all of its parameters set to non-trainable. I have a slightly different architecture than the one you created, but the overall implementation technique is the same as yours.
My code is here -> https://github.com/ARANIKC/adversarial-learning/blob/master/GAN/NetworkGenerator/GAN.py
This is a bit confusing to me. Are the Keras summary method and the fit method inconsistent with each other?
WDYT?
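The apparent inconsistency can be reproduced with a toy sketch in plain Python (SketchLayer and SketchModel are invented names, not the real Keras API): summary() reads the layers' current trainable flags, while training uses the set of weights captured when compile() was called, so the two can disagree without either being wrong.

```python
class SketchLayer:
    """Stand-in for a Keras layer: a parameter count and a trainable flag."""
    def __init__(self, params):
        self.params = params
        self.trainable = True

class SketchModel:
    """summary() reads current flags; training uses the compile-time snapshot."""
    def __init__(self, layers):
        self.layers = layers
        self._compiled_trainable = None

    def compile(self):
        # fit() would update only this snapshot.
        self._compiled_trainable = [l for l in self.layers if l.trainable]

    def summary(self):
        # Reads the *current* flags, not the snapshot.
        total = sum(l.params for l in self.layers)
        trainable = sum(l.params for l in self.layers if l.trainable)
        return {"total": total, "trainable": trainable,
                "non_trainable": total - trainable}

    def trained_params(self):
        # What training would actually update.
        return sum(l.params for l in self._compiled_trainable)

disc = SketchModel([SketchLayer(448), SketchLayer(2320)])
disc.compile()                   # compiled while trainable
for layer in disc.layers:
    layer.trainable = False      # frozen afterwards, for the combined model

print(disc.summary())            # reports trainable params: 0 ...
print(disc.trained_params())     # ... yet training still updates 2768 params
```

This is why the printed summary shows "Trainable params: 0" for the discriminator even though its weights keep changing during training: the summary and the compiled training function are simply looking at different things.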
@eriklindernoren Thank you for your patient explanation, I am now clear about this method!