cvae-gan-zoos-pytorch-beginner's Introduction

CVAE-GAN-zoos-PyTorch-Beginner

First of all, thanks to the few friends who starred this little-known project.

Walkthrough (originally in Chinese):

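If this is your first encounter with autoencoders (AE) and generative adversarial networks (GAN), this should be a useful and efficient learning resource. Everything is written in PyTorch with clear formatting, which makes it well suited to PyTorch newcomers. All models in this project currently generate images from the MNIST dataset. MNIST is small enough to train on a CPU alone, so it is a beginner-friendly dataset.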

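This project contains the following models: AE (autoencoder), DAE (denoising autoencoder), VAE (variational autoencoder), GAN (generative adversarial network), CGAN (conditional GAN), DCGAN (deep convolutional GAN), WGAN (Wasserstein GAN), WGAN-GP (WGAN with gradient penalty), VAE-GAN (variational autoencoder GAN), and CVAE-GAN (conditional variational autoencoder GAN). PS: some of the Chinese names used for these models in the original text were my own translations, haha!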

Suggested order for studying these models (the same order is listed in the English summary further down):


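Running AE.py automatically creates a data folder in the same directory, which holds the automatically downloaded MNIST dataset. It also creates img_AE, which holds the images the model produces at each epoch.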

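AE and DAE are very similar. Neither is a generative model; they are networks that simply compress the data. After AE.py finishes, it saves a model file named AE_z2.pth. You can then run AE_test.py, which shows how the network compresses the handwritten digits: when they are mapped to 2 dimensions, identical digits cluster together, somewhat like cluster analysis. It is a form of unsupervised learning. Here is the AE's encoding plot: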

Identical digits lie closer to each other.
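To make the 2-D compression concrete, here is a minimal sketch of an autoencoder with a two-dimensional bottleneck. The class name TinyAutoEncoder and the layer sizes are assumptions chosen for illustration, not the layers used in AE.py, and a random batch stands in for MNIST.

```python
import torch
from torch import nn


class TinyAutoEncoder(nn.Module):
    """Hypothetical fully connected AE with a 2-D bottleneck (not the repo's exact layers)."""

    def __init__(self, latent_dim: int = 2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, 28 * 28), nn.Sigmoid(),
        )

    def forward(self, x):
        code = self.encoder(x.flatten(start_dim=1))      # (batch, 2) point per image
        recon = self.decoder(code).view(-1, 1, 28, 28)   # back to image shape
        return code, recon


torch.manual_seed(0)
model = TinyAutoEncoder()
images = torch.rand(8, 1, 28, 28)              # stand-in for an MNIST batch
code, recon = model(images)
loss = nn.functional.mse_loss(recon, images)   # reconstruction error only
loss.backward()
```

Plotting the two columns of code for a trained model, colored by digit, gives the kind of scatter plot described above.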

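This maps the 28*28 handwritten digits onto two dimensions, x and y, and shows which image corresponds to each point of the square region where x and y lie in [-2, 2]. The picture has some of the right flavor, but the digits near the center are brighter, because this network can only encode and cannot generate. For better results, go straight to the images generated by CVAE-GAN; they look great, and you will soon be able to generate them yourself.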
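One way to produce such a picture is to decode every point of a regular grid over [-2, 2] x [-2, 2] and stitch the results together. The sketch below assumes a decoder that takes a 2-D code; the untrained decoder here is a hypothetical stand-in for the trained one, and the grid size of 10 is arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in decoder; in practice the trained decoder would be loaded instead.
decoder = nn.Sequential(
    nn.Linear(2, 128), nn.ReLU(),
    nn.Linear(128, 28 * 28), nn.Sigmoid(),
)

steps = 10
axis = torch.linspace(-2.0, 2.0, steps)
grid = torch.cartesian_prod(axis, axis)        # (steps * steps, 2) latent points

with torch.no_grad():
    tiles = decoder(grid).view(steps, steps, 28, 28)

# Arrange the tiles as one big image: rows follow the first coordinate, columns the second.
canvas = tiles.permute(0, 2, 1, 3).reshape(steps * 28, steps * 28)
```

The resulting canvas can be written to disk with any image-saving helper.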


CGAN is the conditional GAN. Although a GAN is trained without labels, we can add label information and then generate images for a specified label:
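A common way to add the label is to concatenate a one-hot encoding of it to the noise vector. The sketch below shows only that conditioning step; the generator layers and the sizes are assumptions, not the ones in the repository's CGAN script.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
noise_dim, num_classes = 100, 10

# Hypothetical generator: its input is the noise concatenated with a one-hot label.
generator = nn.Sequential(
    nn.Linear(noise_dim + num_classes, 256), nn.ReLU(),
    nn.Linear(256, 28 * 28), nn.Tanh(),
)

labels = torch.tensor([3, 3, 7, 7])                   # digits we want to generate
one_hot = F.one_hot(labels, num_classes).float()      # (4, 10)
z = torch.randn(len(labels), noise_dim)               # (4, 100)
fake = generator(torch.cat([z, one_hot], dim=1)).view(-1, 1, 28, 28)
```

The discriminator is usually given the same label alongside the image, so that it can penalize images that do not match their label.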


VAE is the generative model in the AE family. Here are images generated by the VAE:
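What turns an autoencoder into a generative model is that the VAE encoder outputs a distribution rather than a single code, and the loss pulls that distribution toward a standard normal. A minimal sketch of the two standard pieces, the reparameterization trick and the loss, follows; the function names are chosen here for illustration and are not taken from the repository.

```python
import torch
import torch.nn.functional as F


def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I), keeping gradients for mu and logvar."""
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)


def vae_loss(recon, target, mu, logvar):
    """Reconstruction term plus KL(q(z|x) || N(0, I)), both summed over the batch."""
    recon_term = F.binary_cross_entropy(recon, target, reduction="sum")
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_term + kl_term


torch.manual_seed(0)
mu, logvar = torch.zeros(4, 2), torch.zeros(4, 2)
z = reparameterize(mu, logvar)   # with mu = 0 and logvar = 0 this is a draw from N(0, I)
```

After training, new images come from feeding z drawn from N(0, I) to the decoder, with no encoder involved.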


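VAE-WGANGP is the best-performing label-free, unsupervised generative network in this project. Its images are sharp, and the transitions between them are smooth.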


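CVAE-GAN is the best-performing model in this project; it is a supervised generative network that uses label information. I originally tried to turn it into CVAE-WGANGP, but once the labels were added I could not work out how the WGAN-GP loss should handle them in the gradient penalty, so I did not manage to implement it. Still, CVAE-GAN's results are quite good.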
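For reference, one approach often seen in conditional WGAN-GP code is to interpolate only the images, pass the unchanged label to the critic, and take the gradient with respect to the interpolated images alone. The sketch below illustrates that idea under those assumptions; CondCritic and gradient_penalty are hypothetical names, and this is not something the project implements.

```python
import torch
from torch import nn


class CondCritic(nn.Module):
    """Hypothetical critic that scores an image together with a one-hot label."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28 * 28 + num_classes, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
        )

    def forward(self, img, one_hot):
        return self.net(torch.cat([img.flatten(start_dim=1), one_hot], dim=1))


def gradient_penalty(critic, real, fake, one_hot):
    """WGAN-GP penalty; the gradient is taken w.r.t. the interpolated images only."""
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    mixed = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = critic(mixed, one_hot)
    grads, = torch.autograd.grad(
        outputs=scores, inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,   # keep the graph so the penalty itself can be backpropagated
    )
    return ((grads.flatten(start_dim=1).norm(2, dim=1) - 1) ** 2).mean()


torch.manual_seed(0)
critic = CondCritic()
real, fake = torch.rand(4, 1, 28, 28), torch.rand(4, 1, 28, 28)
one_hot = torch.eye(10)[torch.tensor([0, 1, 2, 3])]
gp = gradient_penalty(critic, real, fake, one_hot)
gp.backward()   # the penalty is differentiable w.r.t. the critic's weights
```

In a critic update this term would be scaled by a coefficient (10 is a common choice) and added to the Wasserstein loss.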

First, we can generate a specified digit in different styles:

You can generate any digit images you like.

We can also watch how one digit gradually turns into another:

You can see how one digit changes into a different one. It's interesting!
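One plausible way to obtain such a morphing sequence is to keep the latent code fixed and blend the label condition linearly from one digit to the other. Whether the repository interpolates the label, the latent code, or both is not stated here, so the sketch below is an assumption; the decoder is an untrained, hypothetical stand-in.

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim, num_classes = 16, 10
# Hypothetical conditional decoder taking [z, one-hot label]; stands in for the trained one.
decoder = nn.Sequential(
    nn.Linear(latent_dim + num_classes, 128), nn.ReLU(),
    nn.Linear(128, 28 * 28), nn.Sigmoid(),
)


def one_hot(digit):
    """Return the one-hot row for a digit class."""
    return torch.eye(num_classes)[digit]


z = torch.randn(latent_dim)               # keep the style code fixed
start, end = one_hot(1), one_hot(7)       # morph a "1" into a "7"
frames = []
with torch.no_grad():
    for t in torch.linspace(0.0, 1.0, 8):
        cond = (1 - t) * start + t * end  # linear blend of the two conditions
        frames.append(decoder(torch.cat([z, cond])).view(28, 28))
frames = torch.stack(frames)              # (8, 28, 28), one image per step
```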


The English version is not finished yet. For beginners, this should be a good place to start with VAEs, GANs, and CVAE-GAN.

This repository contains AE, DAE, VAE, GAN, CGAN, DCGAN, WGAN, WGAN-GP, VAE-GAN, and CVAE-GAN. All are written in PyTorch and all use the MNIST dataset; you do not need to download anything other than this GitHub repository.

If you are new to GANs and autoencoders, I advise studying these models in the following order.

1. GAN -> DCGAN -> WGAN -> WGAN-GP

2. GAN -> CGAN

3. AE -> DAE -> VAE

4. Once you have finished all the models above, it is time to study CVAE-GAN.

I spent two days rebuilding all these models in PyTorch, and I believe you can do better and faster.

Let's see the results of CVAE-GAN:

You can generate any digit images you like.

You can see how one digit changes into a different one. It's interesting!

cvae-gan-zoos-pytorch-beginner's People

Contributors

yixinchen-ai

cvae-gan-zoos-pytorch-beginner's Issues

GAN network question

In WDCGAN, the main function contains a loop
for a in range(3)
What is it for?

The lack of documents

I ran into this problem while running the program: FileNotFoundError: [Errno 2] No such file or directory: 'C:\Users\dell\Desktop\下载的程序代码\CVAE-GAN-zoos-PyTorch-Beginner-master\CVAE-GAN-zoos-PyTorch-Beginner-master\CVAE-GAN/CVAE-GAN-Discriminator.pth'. Thank you very much for your answer!
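The message says the saved discriminator weights are not where the script looks for them. A likely reading, which is an assumption here, is that the checkpoint has not been created yet because the training script has not been run. A small guard like the following makes that case explicit before any loading is attempted; find_checkpoint is a hypothetical helper, and only the folder and file names come from the error message.

```python
from pathlib import Path


def find_checkpoint(base_dir, name="CVAE-GAN-Discriminator.pth"):
    """Return the checkpoint path, or None when the file has not been created yet."""
    path = Path(base_dir) / "CVAE-GAN" / name   # pathlib picks the right separator
    return path if path.is_file() else None


ckpt = find_checkpoint(".")
if ckpt is None:
    print("Checkpoint not found; it may need to be created by running training first.")
else:
    print(f"Loading weights from {ckpt}")
```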

CVAE-GAN training question

In CVAE-GAN.py, lines 168-171:
# Update VAE (G), step 2
output = D(recon_data)
real_label = torch.ones(batch_size).to(device)
vae_loss2 = criterion(output,real_label)
Shouldn't line 170 be real_label = torch.zeros...?
The input here, recon_data, is the generator's output, so shouldn't the discriminator classify it as fake?
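On this question: in the usual GAN formulation, the labels depend on which network is being updated. When the discriminator is updated, generated samples are labelled fake (zeros). When the generator is updated, the same samples are labelled real (ones), because the generator's goal is to make the discriminator output 1. Assuming criterion is a binary cross-entropy loss, the toy numbers below show the direction of the gradient in each case; the values in d_out are made up.

```python
import torch
from torch import nn

criterion = nn.BCELoss()
# Pretend D currently rates the reconstructions as mostly fake.
d_out = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)

# Generator step: the target is "real" (ones), so the loss falls as D(G(z)) rises.
g_loss = criterion(d_out, torch.ones_like(d_out))
g_loss.backward()
# Every entry of d_out.grad is negative, so gradient descent pushes D(G(z)) toward 1.

# Discriminator step: the same samples are labelled "fake" (zeros).
d_loss_fake = criterion(d_out.detach(), torch.zeros(3))
```

With zeros in the generator step, the generator would instead be trained to make its outputs easier to reject.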

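A problem common throughout the code

When training the discriminator, the fake samples produced by the generator need to be cut off from the computation graph by adding .detach(), and it seems none of the scripts do this. The results of running this code are also not very satisfying; overall I think there are quite a few errors.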
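A minimal sketch of a discriminator update that uses .detach() follows. G and D here are tiny hypothetical stand-ins rather than the repository's networks; the point is only that no gradient reaches the generator during the discriminator step.

```python
import torch
from torch import nn

torch.manual_seed(0)
G = nn.Linear(8, 4)                                  # hypothetical generator
D = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())     # hypothetical discriminator
criterion = nn.BCELoss()
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)

real = torch.randn(16, 4)
fake = G(torch.randn(16, 8))

opt_D.zero_grad()
loss_real = criterion(D(real), torch.ones(16, 1))
loss_fake = criterion(D(fake.detach()), torch.zeros(16, 1))  # cut the graph at G's output
(loss_real + loss_fake).backward()
opt_D.step()
```

Because fake itself is not modified, it can still be passed through D again for the generator update afterwards.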

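cVAE-GAN may have a problem

Forgive me, but looking at the cVAE-GAN code, isn't it still using a VAE as the decoder? It doesn't seem to have anything to do with CVAE itself.

I do see that a classifier is trained separately, but is that what the cVAE framework looks like?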

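VAE-GAN code raises an error when run directly

My environment is torch==1.8.1. The VAE example and the other examples run without problems so far; only this example fails.

Running your code directly gives the following error: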


RuntimeError Traceback (most recent call last)
Cell In[9], line 204
202 output = D(recon_data)
203 errVAE = criterion(output, real_label)
--> 204 errVAE.backward()
205 D_G_z2 = output.mean().item()
206 optimizerVAE.step()

File d:\ProgramData\Anaconda3\envs\pt\lib\site-packages\torch\tensor.py:245, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
236 if has_torch_function_unary(self):
237 return handle_torch_function(
238 Tensor.backward,
239 (self,),
(...)
243 create_graph=create_graph,
244 inputs=inputs)
--> 245 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File d:\ProgramData\Anaconda3\envs\pt\lib\site-packages\torch\autograd\__init__.py:145, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
142 if retain_graph is None:
143 retain_graph = create_graph
--> 145 Variable._execution_engine.run_backward(
146 tensors, grad_tensors_, retain_graph, create_graph, inputs,
147 allow_unreachable=True, accumulate_grad=True)

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [16, 1, 4, 4]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
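This error typically appears in GAN training loops when a loss is backpropagated through discriminator outputs that were computed before optimizer.step() changed the discriminator's weights in place. That this is the cause here is an assumption based on the traceback (the [16, 1, 4, 4] tensor looks like a convolution weight). The sketch below reproduces the pattern with small hypothetical networks and shows the working order, which runs D again after its update.

```python
import torch
from torch import nn

torch.manual_seed(0)
G = nn.Linear(8, 4)   # hypothetical generator and discriminator
D = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 1), nn.Sigmoid())
criterion = nn.BCELoss()
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)
opt_G = torch.optim.SGD(G.parameters(), lr=0.1)
ones, zeros = torch.ones(16, 1), torch.zeros(16, 1)
noise = torch.randn(16, 8)

fake = G(noise)
stale_output = D(fake)   # this graph refers to D's current weights

# Discriminator step: step() updates D's weights in place.
opt_D.zero_grad()
criterion(D(fake.detach()), zeros).backward()
opt_D.step()

try:
    # D's weights have changed since stale_output was computed, so autograd refuses.
    criterion(stale_output, ones).backward()
    stale_failed = False
except RuntimeError:
    stale_failed = True

# Working order: run the generator and D again after D's update, then backpropagate.
opt_G.zero_grad()
criterion(D(G(noise)), ones).backward()
opt_G.step()
```

Calling torch.autograd.set_detect_anomaly(True), as the error message suggests, helps locate the forward operation whose saved tensor was modified.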

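The DAE seems to have a problem

img = img + noise * 0.1
code, decode = AE(img) # original comment: "feed the real image into the discriminator"
loss = criterion(decode, img)
Shouldn't the DAE loss compare the output with the input from before the noise was added?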
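In a denoising autoencoder the network receives the corrupted image, while the reconstruction target is the clean one. A minimal sketch of that training step follows, keeping the clean and noisy tensors in separate variables; TinyAE is a hypothetical stand-in for the repository's model, and a random batch stands in for MNIST.

```python
import torch
from torch import nn


class TinyAE(nn.Module):
    """Hypothetical stand-in for the repo's AE; returns (code, reconstruction)."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(28 * 28, 32)
        self.dec = nn.Sequential(nn.Linear(32, 28 * 28), nn.Sigmoid())

    def forward(self, x):
        code = self.enc(x.flatten(start_dim=1))
        return code, self.dec(code).view_as(x)


torch.manual_seed(0)
AE = TinyAE()
criterion = nn.MSELoss()

clean = torch.rand(8, 1, 28, 28)                  # stand-in for an MNIST batch
noisy = clean + 0.1 * torch.randn_like(clean)     # corrupt the input only
code, decode = AE(noisy)
loss = criterion(decode, clean)                   # the target is the clean image
loss.backward()
```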
