
Comments (4)

eriklindernoren commented on July 27, 2024

Since self.discriminator.trainable = False is set after the discriminator is compiled, it does not affect the training of the discriminator. However, since it is set before the combined model is compiled, the discriminator's layers are frozen when the combined model is trained. You can read more about this in the Keras FAQ, under 'How can I "freeze" Keras layers?':
https://keras.io/getting-started/faq/
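The compile ordering described above can be condensed into a few lines. This is a minimal sketch, not the exact code from either repository; the layer sizes and optimizer settings are arbitrary placeholders:

```python
from keras.models import Sequential, Model
from keras.layers import Input, Dense

# Stand-alone discriminator: compiled while trainable, so calls like
# discriminator.train_on_batch(...) will still update its weights.
discriminator = Sequential([Dense(1, activation='sigmoid', input_shape=(28,))])
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

generator = Sequential([Dense(28, activation='tanh', input_shape=(100,))])

# Freeze AFTER compiling the discriminator but BEFORE compiling the combined
# model: the trainable state is captured per-model at compile time, so only
# the combined model trains with the discriminator frozen.
discriminator.trainable = False
z = Input(shape=(100,))
combined = Model(z, discriminator(generator(z)))
combined.compile(optimizer='adam', loss='binary_crossentropy')
```

Because the flag was flipped between the two compile calls, training the combined model updates only the generator's weights, while the separately compiled discriminator keeps training normally.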

If you print the discriminator's weights during training (for example by adding the two lines below), you will also see that they change on every training iteration.

    for layer in self.discriminator.layers:
        print(layer.get_weights())

from keras-gan.

eriklindernoren commented on July 27, 2024

That is strange, because even when I print the model summaries after both the discriminator and the combined model are compiled, I see that the discriminator's layers are trainable while the discriminator's layers inside the combined model are frozen. I don't see why your approach does not work. The difference I notice when I look at your code is that you don't define the generator and the discriminator the same way I do. In your _generator method:

def _generator(inp_size):
    """
    construct the generator model
    :param inp_size: size of the input noise vector
    :return: generator model
    """

    # define the generator model
    generator_model = Sequential()

    # dense block for bringing spatial integrity
    generator_model.add(Dense(128, activation='relu', input_shape=(inp_size,)))
    ....
    generator_model.add(Conv2DTranspose(filters=3, kernel_size=(3, 3),
                                        padding='same', activation='sigmoid'))

    # final output of the model is 32 x 32 x 3 (cifar-10 image size)
    return generator_model

You do not define inputs and outputs for your model. When you then define the combined model, you refer to the generator's inputs, which were never defined: comb = Model(inputs=gen.inputs, outputs=dis(gen.outputs)). Compare this to my implementation, where I define the input (noise = Input(shape=noise_shape)) and output (img = model(noise)) of the model (Model(noise, img)) before I return and compile it:

    def build_generator(self):

        noise_shape = (100,)

        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        ...
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)
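The recommended wiring can be condensed as follows. This is a minimal sketch with arbitrary placeholder sizes and layer choices, not the exact code from either repository:

```python
from keras.models import Sequential, Model
from keras.layers import Input, Dense

def build_generator(noise_size=4, img_size=8):
    # Internal Sequential stack (placeholder layers).
    model = Sequential([Dense(img_size, activation='sigmoid',
                              input_shape=(noise_size,))])
    # Explicit input tensor and output tensor, as in the snippet above.
    noise = Input(shape=(noise_size,))
    img = model(noise)
    return Model(noise, img)

gen = build_generator()
# gen.inputs and gen.outputs are now well defined and can be reused when
# wiring the combined model, e.g. combined = Model(noise, dis(gen(noise))).
```

With the explicit Input tensor the functional model's inputs and outputs are always available, so the combined model can be assembled without relying on a Sequential model having been built implicitly.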

You could try it this way and see if it fixes your problem. I have read about similar issues regarding model.trainable before. You can read more here:
keras-team/keras#4674


akanimax commented on July 27, 2024

@eriklindernoren Thank you for replying. As you pointed out, the FAQ says that changing the trainable property after compiling does not affect the already-compiled model.
But here is the output of my discriminator.summary():

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 32, 32, 16)        448       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 16)        2320      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 32, 32, 16)        2320      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 16, 16, 16)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 32)        12832     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 16, 16, 32)        25632     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 16, 16, 32)        25632     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 8, 8, 32)          0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 8, 8, 64)          51264     
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 8, 8, 64)          102464    
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 8, 8, 64)          102464    
_________________________________________________________________
flatten_1 (Flatten)          (None, 4096)              0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 4096)              16384     
_________________________________________________________________
dense_7 (Dense)              (None, 1024)              4195328   
_________________________________________________________________
batch_normalization_2 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_8 (Dense)              (None, 256)               262400    
_________________________________________________________________
batch_normalization_3 (Batch (None, 256)               1024      
_________________________________________________________________
dense_9 (Dense)              (None, 32)                8224      
_________________________________________________________________
batch_normalization_4 (Batch (None, 32)                128       
_________________________________________________________________
dense_10 (Dense)             (None, 1)                 33        
=================================================================
Total params: 4,812,993
Trainable params: 0
Non-trainable params: 4,812,993
_________________________________________________________________

and here is the output of combined.summary():

 _________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1_input (InputLayer)   (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               2176      
_________________________________________________________________
dense_2 (Dense)              (None, 256)               33024     
_________________________________________________________________
dense_3 (Dense)              (None, 512)               131584    
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              525312    
_________________________________________________________________
dense_5 (Dense)              (None, 2048)              2099200   
_________________________________________________________________
dense_6 (Dense)              (None, 4096)              8392704   
_________________________________________________________________
reshape_1 (Reshape)          (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 32)        100384    
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 16, 16, 32)        50208     
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 16, 16, 32)        50208     
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 32, 32, 16)        12816     
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 32, 32, 16)        6416      
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 32, 32, 16)        6416      
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 32, 32, 8)         1160      
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 32, 32, 8)         584       
_________________________________________________________________
conv2d_transpose_9 (Conv2DTr (None, 32, 32, 3)         219       
_________________________________________________________________
sequential_2 (Sequential)    (None, 1)                 4812993   
=================================================================
Total params: 16,225,404
Trainable params: 11,412,411
Non-trainable params: 4,812,993
_________________________________________________________________

As is apparent, the combined model looks fine, but the discriminator has all of its parameters reported as non-trainable.

My architecture is slightly different from the one you created, but the overall implementation technique is the same as yours.
My code is here -> https://github.com/ARANIKC/adversarial-learning/blob/master/GAN/NetworkGenerator/GAN.py

This is a bit confusing to me. Are Keras's summary method and fit method inconsistent with each other?
WDYT?
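The apparent inconsistency has a mechanical explanation: summary() reads the current value of the trainable attribute, while a compiled model trains with the state captured at compile time (this is the behavior the Keras FAQ describes). A minimal sketch, assuming the classic Keras semantics discussed in this thread:

```python
from keras.models import Sequential
from keras.layers import Dense

d = Sequential([Dense(1, input_shape=(4,))])
d.compile(optimizer='adam', loss='mse')  # trainable state is captured here

d.trainable = False  # changes what summary() reports from now on ...
# ... but d.fit / d.train_on_batch would still update the weights, because
# the training function was already built with the layers marked trainable.
d.summary()  # Trainable params is now reported as 0
```

So summary() showing zero trainable parameters does not by itself mean the discriminator has stopped training; the weight-printing check from earlier in the thread is the reliable way to tell.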


hujinsen commented on July 27, 2024

@eriklindernoren Thank you for your patient explanation, I am clear about this method now!

