shaofanl / cyclegan-keras
A Keras implementation of CycleGAN
Hi,
Thanks for your valuable code, but I don't quite follow the logic for calculating the GAN loss in cyclegan.py. My understanding is that the discriminator tries to label the real image as real (1) and the fake image as fake (0), so in cyclegan.py
for _ in range(opt.d_iter):
    _, D_loss_real_A, D_loss_fake_A, D_loss_real_B, D_loss_fake_B = \
        self.D_trainner.train_on_batch([real_A, fake_A, real_B, fake_B],
                                       [zeros, ones*0.9, zeros, ones*0.9])
should be
for _ in range(opt.d_iter):
    _, D_loss_real_A, D_loss_fake_A, D_loss_real_B, D_loss_fake_B = \
        self.D_trainner.train_on_batch([real_A, fake_A, real_B, fake_B],
                                       [ones*0.9, zeros, ones*0.9, zeros])
Meanwhile, the generator tries to fool the discriminator into treating the fake image as real, so in cyclegan.py
_, G_loss_fake_B, G_loss_fake_A, G_loss_rec_A, G_loss_rec_B = \
    self.G_trainner.train_on_batch([real_A, real_B],
                                   [zeros, zeros, real_A, real_B])
should be
_, G_loss_fake_B, G_loss_fake_A, G_loss_rec_A, G_loss_rec_B = \
    self.G_trainner.train_on_batch([real_A, real_B],
                                   [ones, ones, real_A, real_B])
Have you tried the same experiment, e.g. horse2zebra, with this code? After how many epochs do you obtain reasonable results? I tried running the code with the TensorFlow backend, with slight modifications, but never obtained correct results. I am wondering whether the logic mentioned above has something to do with it. Please advise. Thanks.
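The question hinges on which label means "real". A minimal NumPy sketch, assuming a least-squares (MSE) adversarial loss, compares the two conventions; `lsgan_losses` and the scores are hypothetical, not taken from the repository. Mirroring the discriminator's scores and swapping the targets gives identical losses, so what matters is that the generator's target equals the discriminator's target for real images.

```python
import numpy as np

def lsgan_losses(d_real, d_fake, real_label, fake_label):
    """Least-squares (MSE) GAN losses for one batch of discriminator scores.

    Returns (d_loss, g_loss). The generator is pushed toward `real_label`,
    i.e. it tries to make D score its fakes like real images.
    """
    d_loss = np.mean((d_real - real_label) ** 2) + np.mean((d_fake - fake_label) ** 2)
    g_loss = np.mean((d_fake - real_label) ** 2)
    return d_loss, g_loss

# Hypothetical discriminator scores for two real and two fake images.
d_real = np.array([0.8, 0.6])
d_fake = np.array([0.3, 0.1])

# Convention 1: real -> 1, fake -> 0.
conv1 = lsgan_losses(d_real, d_fake, real_label=1.0, fake_label=0.0)
# Convention 2: real -> 0, fake -> 1, with a discriminator whose scores are mirrored (1 - s).
conv2 = lsgan_losses(1 - d_real, 1 - d_fake, real_label=0.0, fake_label=1.0)

print(np.allclose(conv1, conv2))  # True: the conventions are equivalent
```

On this reading, the quoted targets (real → 0, fake → 0.9, generator → 0) are self-consistent under the flipped convention, and so is the proposed change; the two differ mainly in whether the 0.9 smoothing sits on the fake or the real target. Training would break only if the generator aimed at the discriminator's "fake" label.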
I found that in cyclegan.py you are using opt.idloss as the loss weight of the identity loss:
if opt.idloss > 0:
    G_trainner = Model([real_A, real_B],
                       [dis_fake_B, dis_fake_A, rec_A, rec_B, fake_B, fake_A])
    G_trainner.compile(Adam(lr=opt.lr, beta_1=opt.beta1),
                       loss=['MSE', 'MSE', 'MAE', 'MAE', 'MAE', 'MAE'],
                       loss_weights=[1, 1, opt.lmbd, opt.lmbd, opt.idloss, opt.idloss])
but the original PyTorch version uses:
if lambda_idt > 0:
    # G_A should be identity if real_B is fed.
    idt_A = self.netG_A(self.real_B)
    loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt  # loss part?
So I think the weight of the identity loss should be opt.idloss*opt.lmbd?
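A minimal sketch of the proposed weighting, assuming `opt.lmbd` plays the role of both `lambda_A` and `lambda_B` from the PyTorch code. `generator_loss_weights` is a hypothetical helper; the output order follows the `G_trainner` outputs in cyclegan.py.

```python
def generator_loss_weights(lmbd, idloss):
    """Hypothetical helper: loss weights for the six G_trainner outputs
    [dis_fake_B, dis_fake_A, rec_A, rec_B, fake_B, fake_A].

    Adversarial terms get weight 1, cycle terms get lmbd, and identity terms
    get idloss * lmbd, mirroring `lambda_B * lambda_idt` in the PyTorch code.
    """
    cycle_w = lmbd
    identity_w = idloss * lmbd
    return [1, 1, cycle_w, cycle_w, identity_w, identity_w]

# Assuming the PyTorch defaults lambda_A = lambda_B = 10 and lambda_identity = 0.5:
print(generator_loss_weights(10.0, 0.5))  # [1, 1, 10.0, 10.0, 5.0, 5.0]
```

In the compile call this would become `loss_weights=generator_loss_weights(opt.lmbd, opt.idloss)`. The identity terms then scale with the cycle weight instead of being set independently.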
Isn't there any freezing of the D weights when training G in your implementation, or am I missing it?
In resnet.py:
def resnet_6blocks(input_shape, output_nc, ngf, **kwargs):
    ks = 3
    f = 7
    p = (f-1)/2  # should be p = (f-1)//2
Using floor division keeps p an integer and avoids passing a float argument, which causes an error later.
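Under Python 3, `/` always returns a float, so `p` becomes 3.0 instead of 3. A minimal sketch of why that matters: anything that needs an integer count rejects the float. List repetition stands in here for the padding code in resnet.py, which is not shown.

```python
f = 7                    # kernel size used in resnet_6blocks
p_float = (f - 1) / 2    # true division: 3.0, a float
p_int = (f - 1) // 2     # floor division: 3, an int

# Integer-only operations (repetition, slicing, shape tuples) reject the float.
try:
    pad_spec = [0] * p_float
except TypeError as err:
    print("float padding fails:", err)

pad_spec = [0] * p_int
print(p_float, p_int, pad_spec)  # 3.0 3 [0, 0, 0]
```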
ValueError: Dimensions must be equal, but are 3 and 4 for 'gen_A/instance_normalization2d_21/mul' (op: 'Mul') with input shapes: [1,3,1,1], [?,4,128,64].
This project is great; I was just searching for a CycleGAN implemented in Keras. When I run the demo, I encounter the error "expected input_3 to have shape (None, 3, 128, 128) but got array with shape (1, 128, 128, 3)" in "fake_A_pool.extend(self.BtoA.predict(real_B))". It seems there is a problem with the array dimension conversion. Can you confirm it? Thank you very much.
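The error message suggests the model expects channels-first batches (N, C, H, W), while the images arrive channels-last (N, H, W, C). One hedged workaround, assuming the model itself is correct for a channels-first layout, is to transpose each batch before calling `predict`. Another is to align Keras' global `image_data_format` setting (in keras.json) with the layout the model was built for. The helper names below are hypothetical.

```python
import numpy as np

def to_channels_first(batch):
    """Convert an NHWC batch (N, H, W, C) to NCHW (N, C, H, W)."""
    return np.transpose(batch, (0, 3, 1, 2))

def to_channels_last(batch):
    """Convert an NCHW batch back to NHWC, e.g. for saving images."""
    return np.transpose(batch, (0, 2, 3, 1))

# Hypothetical input batch with the shape reported in the error message.
real_B = np.random.rand(1, 128, 128, 3).astype("float32")
real_B_cf = to_channels_first(real_B)
print(real_B_cf.shape)  # (1, 3, 128, 128)
```

The converted batch would then be passed as `self.BtoA.predict(real_B_cf)`. Images coming out of the generator would need `to_channels_last` before being displayed or saved.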