
glow-pytorch's Introduction

Glow

This is a PyTorch implementation of the paper "Glow: Generative Flow with Invertible 1x1 Convolutions". Most modules are adapted from the official TensorFlow version openai/glow.

TODO

  • Glow model. The model is implemented as described in the original paper; some functions are adapted from the official TF version. Most modules are tested.
  • Trainer, builder and hparams loaded from JSON.
  • Inference after training.
  • Test LU_decomposed 1x1 conv2d.

Scripts

  • Train a model with
    train.py <hparams> <dataset> <dataset_root>
    
  • Generate z_delta and manipulate attributes with
    infer_celeba.py <hparams> <dataset_root> <z_dir>
    

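For example, with the CelebA setup described below (hparams/celeba.json and the celeba dataset key; the dataset path and the z_dir output directory are hypothetical), the invocations might look like:

    python train.py hparams/celeba.json celeba /path/to/CelebA
    python infer_celeba.py hparams/celeba.json /path/to/CelebA z_dir
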
Training result

Currently, I have trained the model for 45,000 batches with hparams/celeba.json on the CelebA dataset. In short, I trained with the following parameters:

HParam            Value
----------------  ----------------------------
image_shape       (64, 64, 3)
hidden_channels   512
K                 32
L                 3
flow_permutation  invertible 1x1 conv
flow_coupling     affine
batch_size        12 on each GPU, with 4 GPUs
learn_top         false
y_condition       false

  • Download pre-trained model from Dropbox

Reconstruction

The following are some samples from the training phase. Row 1: reconstructed; Row 2: original.

Manipulate attribute

Use the method described in the paper to calculate z_pos and z_neg for a given attribute. Then z_delta = z_pos - z_neg is the direction along which to manipulate the original image.
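
For illustration, here is a minimal sketch of that manipulation (assuming a trained graph built as in the issue snippets further down, a base latent z_base from graph.generate_z, and z_pos / z_neg already computed for the chosen attribute; the interpolation weights are my own choice):

import torch

z_delta = z_pos - z_neg
# move the base latent along the attribute direction and decode each step
alphas = torch.linspace(-1.0, 1.0, steps=5)
with torch.no_grad():
    images = [graph(z=z_base + a * z_delta, reverse=True) for a in alphas]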

  • manipulate Smiling (from negative to positive):

  • manipulate Young (from negative to positive):

  • manipulate Pale_Skin (from negative to positive):

  • manipulate Male (from negative to positive):

Issues

There might be some errors in my code. Please help me figure them out.

glow-pytorch's People

Contributors

chaiyujin, kolchinski


glow-pytorch's Issues

Mathematical Formulation of Objective Function

    def normal_flow(self, x, y_onehot):
        pixels = thops.pixels(x)
        z = x + torch.normal(mean=torch.zeros_like(x),
                             std=torch.ones_like(x) * (1. / 256.)) #original input added with a small noise
        logdet = torch.zeros_like(x[:, 0, 0, 0])
        logdet += float(-np.log(256.) * pixels)   # ??? How should we add this here?
        # encode
        z, objective = self.flow(z, logdet=logdet, reverse=False)
        # prior
        mean, logs = self.prior(y_onehot)
        objective += modules.GaussianDiag.logp(mean, logs, z)

        if self.hparams.Glow.y_condition:
            y_logits = self.project_class(z.mean(2).mean(2))
        else:
            y_logits = None

        # return
        nll = (-objective) / float(np.log(2.) * pixels)
        return z, nll, y_logits

This normal_flow() function is the core forward propagation part of Glow.
"nll" here is just passed to this line in trainer.py

loss_generative = Glow.loss_generative(nll)

And Glow.loss_generative is just a static function (it just takes the mean):

    @staticmethod
    def loss_generative(nll):
        # Generative loss
        return torch.mean(nll)

So basically nll is just the loss.
Now let's turn our attention to "objective": from my understanding of each nn.Module in this project, it is the sum of the log-determinants of each nn.Module.

Then let's take a broad view of the whole model.
We start from z (the input x with some noise added) and go through many transformations, so the final output is

z = f_K( ... f_2( f_1( x + noise ) ) )

Suppose p() represents the likelihood of each instance and x represents a real image from the dataset. The objective of this generative model should be

log P(x) >= E[ log p(x + noise) ] - M * log(256),    with M = thops.pixels(x)

The right-hand side is always smaller than the left-hand side, so the objective is just to enlarge the RHS. So the objective function of the optimization is to maximize E[ log p(x + noise) ] - M * log(256).

But the computation of "nll" shows that it is actually

nll = -( log p_Z(z) + sum_i log|det J_i| - M * log(256) ) / ( M * log(2) )

Is this really correct?

And by the way, what's the point of this line?

logdet += float(-np.log(256.) * pixels)
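
As a rough arithmetic sketch of what that term does (my reading, assuming 8-bit images scaled to [0, 1], so each pixel bin has width 1/256), with M = thops.pixels(x):

\mathrm{nll} = \frac{-\big(\log p_Z(z) + \sum_i \log\lvert\det J_i\rvert\big)}{M \log 2} + \log_2 256

i.e. the -np.log(256.) * pixels term just adds a constant log2(256) = 8 to the reported nll, the usual discretization correction that relates the continuous density to the discrete 8-bit data.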

Reconstructed images are not like the input images

Hi,

First of all, thank you for a nice repo.

I am trying to map an image x to a latent representation z and then map back to its reconstruction x_hat with a trained model, but x and x_hat are not similar at all.

I understand that it may not be totally identical due to Split2d layer, but the degree of difference is way too severe.

I ran the training script, and the reconstructed images shown in TensorBoard are very similar to their corresponding inputs.

Here are rough snippets that can reproduce my problem.

For defining and loading dataset and model,

import torch  # needed for the encode/decode snippets below
from torchvision import transforms
from glow.config import JsonConfig
from glow.builder import build
from glow.trainer import Trainer
from glow.utils import load
import vision

hparams = JsonConfig('hparams/celeba.json')

transform = transforms.Compose([
    transforms.CenterCrop(hparams.Data.center_crop),
    transforms.Resize(hparams.Data.resize),
    transforms.ToTensor()])
dataset = vision.Datasets['celeba']
dataset = dataset('<path-to-CelebA>', transform=transform)

built = build(hparams, True)
load('glow_celeba.ckpt.pkg', built['graph'])
graph = built['graph']
graph.eval()

Now I want to encode a batch of images and decode them back.

img_x = torch.stack([dataset[i]['x'] for i in range(12)])
img_x = img_x.cuda()

# encode
z, nll, y_logits = graph(img_x, y_onehot=None)

# decode
x_hat = graph(z, reverse=True)

When I visualized img_x and x_hat, I found a significant discrepancy.

import matplotlib.pyplot as plt

# original images
np_img_x = img_x.permute(0,2,3,1).detach().cpu().numpy()
fig, axs = plt.subplots(ncols=5, figsize=(15,5))
for i in range(5):
    axs[i].imshow(np_img_x[i])

image

# reconstructed images
np_x_hat = x_hat.permute(0,2,3,1).detach().cpu().numpy()
fig, axs = plt.subplots(ncols=5, figsize=(15,5))
for i in range(5):
    axs[i].imshow(np_x_hat[i])

image

As you can see, those images are very different.

As a sanity check, I ran unconditioned sampling. It gives reasonably fine images, especially with an apt choice of eps_std, so the model is well trained.

out = graph(z=None, reverse=True, eps_std=0.6)
np_out = out.permute(0,2,3,1).detach().cpu().numpy()
fig, axs = plt.subplots(ncols=5, figsize=(15,5))
for i in range(5):
    axs[i].imshow(np_out[i])

image

Sampling Fails when moving to CIFAR10

Thanks for your pytorch version of Glow.
When I run your code on CIFAR10, I find that the reconstruction is generally good, but the model fails to produce reasonable sampling results.
Reconstruction:
Screenshot from 2019-04-20 17:47:05
Sampling:
Screenshot from 2019-04-20 17:44:23
Do you have any idea how to fix this issue? Thanks.

[encoder part problem] generate_z()

In the inference part, we should use the selected image to generate the latent z value,

as shown in 'z_base = graph.generate_z(dataset[base_index]["x"])'.

But in glow/models.py, the input is repeated with (B, 1, 1, 1).

I do not know why, and this operation needs too much GPU memory; I always run out of memory.

Could you help me?
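
As a possible memory-friendly sketch (my assumption, reusing only the graph(x, y_onehot=None) interface shown in the other issues rather than generate_z itself), one could encode a single image with a batch dimension of 1 instead of repeating it:

import torch

# hypothetical workaround: encode one image with batch size 1
x = dataset[base_index]["x"].unsqueeze(0).cuda()   # shape (1, C, H, W)
with torch.no_grad():
    z_base, _, _ = graph(x, y_onehot=None)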

Inplace operation error occurs with batch size 1

Hi Yujin,
I'm applying generative flow to a 3D model.
Your PyTorch implementation is helpful.
Due to memory constraints, I ran the code with batch size 1, and then an in-place operation error occurs like below.
(it is okay with batch size 2 or 4)

"RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"

I tried to figure out where this error comes from, but failed.
Do you have any idea?

Thanks in advance
Wonmo,

Question about Affine Coupling Parameterisation

I noticed in your affine coupling you are doing:

scale = F.sigmoid(scale + 2)

As opposed to the RealNVP parameterisation:

scale = scale.exp()

Did you find the sigmoid necessary for stability? It seems unusual to me, because it means that the feature-wise transformations can only contract/shrink the input space at every layer.
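
For concreteness, a minimal sketch of the two parameterisations side by side (hypothetical tensors, not the repository's exact coupling code):

import torch

h = torch.randn(2, 6, 8, 8)             # hypothetical NN output of the coupling
shift, scale = h.chunk(2, dim=1)        # split into shift and scale halves

# glow-pytorch / openai-glow style: scale in (0, 1), so its log-determinant is always <= 0
scale_sigmoid = torch.sigmoid(scale + 2.0)
logdet_sigmoid = scale_sigmoid.log().flatten(1).sum(-1)

# RealNVP style: scale in (0, inf), so the log-determinant can have either sign
scale_exp = scale.exp()
logdet_exp = scale.flatten(1).sum(-1)   # log(exp(scale)) = scale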

Hard to understand mean and logs in Split2d Module

class Split2d(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        h = self.conv(z) # enlarge the channel number of z by doubling it
        return thops.split_feature(h, "cross") # split channels by odd and even index

    def forward(self, input, logdet=0., reverse=False, eps_std=None):
        if not reverse:
            z1, z2 = thops.split_feature(input, "split") # split channel by first half and second half
            mean, logs = self.split2d_prior(z1)
            logdet = GaussianDiag.logp(mean, logs, z2) + logdet
            return z1, logdet
        else:
            z1 = input
            mean, logs = self.split2d_prior(z1)
            z2 = GaussianDiag.sample(mean, logs, eps_std)
            z = thops.cat_feature(z1, z2)
            return z, logdet

This is the Split2d class in glow/modules.py,
and my question is on the forward function.

mean, logs = self.split2d_prior(z1)

As I understand it, self.split2d_prior() is only splitting the channels by odd and even index, but the returned values are called mean and logs (the log of the square root of the variance in this project).

Why is that? How can simply splitting a tensor along the channel dimension produce two pieces with totally different meanings? Why can one be treated as the mean and the other as logs?
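
For what it's worth, here is a shape-only sketch of my reading of that code (hypothetical sizes; Conv2dZeros is approximated by a zero-initialised Conv2d): it is the learned convolution, not the channel split itself, that turns z1 into the mean and logs of a Gaussian prior over z2.

import torch
import torch.nn as nn

C = 8                                   # hypothetical channel count at this level
z1 = torch.randn(2, C // 2, 16, 16)     # first half of the "split" along channels

conv = nn.Conv2d(C // 2, C, kernel_size=3, padding=1)
nn.init.zeros_(conv.weight)
nn.init.zeros_(conv.bias)

h = conv(z1)                            # (2, C, 16, 16): twice the channels of z1
mean, logs = h[:, 0::2], h[:, 1::2]     # "cross" split: even channels vs odd channels
# mean and logs have the same shape as z2, so they can parameterise the Gaussian
# used by GaussianDiag.logp / GaussianDiag.sample in the code above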

Broken parallelization

When I try to run the model on several GPUs, I get a numerical error:

Warning: NaN or Inf found in input tensor.
Warning: NaN or Inf found in input tensor.
Warning: NaN or Inf found in input tensor.

While running on a single GPU everything works just fine.

This indicates that there is an issue with parallelization.

License

Hi Yujin,

Thanks for writing glow in pytorch. It's quite helpful for me and for others. If you're okay with doing so, do you mind putting a license on the repository to formally allow others to use the code?

to onnx?

First of all, I would like to thank the author for his contribution to open source. I have a question: can the model be directly exported to ONNX or TensorRT?

How to generate larger images?

Much appreciation for your effort in implementing Glow in PyTorch.

When I change only the image size, to 128x128x3, training fails and the following error occurs after several steps:

RuntimeError: svd_cuda: the updating process of SBDSDC did not converge (error: 23)

Are there any other modifications I need to make when generating larger images?
Or how should I go about generating larger images?

Best wishes!

NLL numbers on popular datasets

Can anyone report NLL numbers on popular datasets like CIFAR-10, ImageNet, and CelebA? My results differ in scale from those reported in the literature.

Dataset Issue

Thanks for providing a well-documented implementation of Glow. I am not able to figure out how you prepared the training data for CelebA. I am using your CelebADataset module but failed to produce images similar to those you have shown. I tried this module on both the aligned-and-cropped version of CelebA and the in-the-wild version as well.

Can you please share your exact data preparation methodology?

Version of Python and PyTorch? Using it for non-image data?

Hi there,

I am just wondering what version of python and pytorch you used for this so I can be on the same page.

Also, would there need to be significant changes to the code base to have this work on 1D datasets of length 8, other than updating the json from [64,64,3] to [8,1]?

Thanks,
Michael

Question about the flow-based model.

Hello,

This is my first time working with flow-based models, and I am still wondering about the roles of the encoder and decoder.

If I do something like this

out = reverse_flow(normal_flow(img))

then will out be exactly the same as the image?
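
A minimal sketch of that round trip using this repository's graph interface (as used in the reconstruction issue above; exact equality is not guaranteed because of the added dequantization noise, the Split2d sampling, and numerical error):

import torch

# x: a batch of images, graph: a trained Glow model (see the loading snippet above)
with torch.no_grad():
    z, nll, y_logits = graph(x, y_onehot=None)   # normal (forward) flow
    x_hat = graph(z, reverse=True)               # reverse flow

print((x - x_hat).abs().max().item())            # how far the round trip is from identity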

Infinite values in generated images

To generate new images, I sample z from a normal distribution with zero mean and 0.6 standard deviation and feed it to the network with the reverse=True argument.
But in many images there are plenty of values greater than 1, and even Inf values!
How can I handle this issue? What is the problem?

Thanks.

what sorts of loss values do you see?

In re-implementing this, I am finding it difficult to drop lower than 9ish - this seems inevitable, since the loss term is made up of (ignoring the pixel factor)

loss = -(loglikelihood_prior + logdet + discretization_correction) / log(2)

  • if z == 0, then the loglikelihood_prior is maximized, and you get a constant "-0.5 * log(2pi)";
  • thanks to the sigmoid(scale + 2), the logdet must always be negative, or 0 in the "best case";
  • and the discretization correction is a constant negative.

I don't see how it is possible to get a loss of "3" or "1" as seen in the Glow paper graphs.
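
Putting rough numbers on that argument (a sketch using the three terms above, per dimension and with logdet = 0 in the best case):

\text{loss} \ge \frac{\tfrac{1}{2}\log(2\pi) + \log 256}{\log 2} \approx \frac{0.92 + 5.55}{0.69} \approx 9.3

which is consistent with having a hard time dropping much below 9.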

How to adapt this code to torchvision.datasets?

I would like to use torchvision.datasets.MNIST to run the code, since the CelebA dataset takes more time to train. Could you please tell me what changes to make if I want to train this code on the MNIST dataset?

Positive logdet

Hi,
In one flow step, there are actnorm, permute, and coupling.
The sigmoid(scale + 2) ensures that the coupling layer gives a non-positive logdet, but there is no similar operation on actnorm and permute to ensure this.
I encounter a positive logdet at the final output when using the network configuration K=20, L=3, because of this.
I would appreciate any ideas on how to solve this.

Invertibility of Glow

Hi Yujin,

Thanks for the nice implementation. From your reconstruction demo, it seems that the implementation cannot guarantee full invertibility of Glow. Could you comment on this?

Best,

Kede

Question about LinearZeros and Conv2dZeros

Hi,
Thanks for the nice PyTorch implementation. In the LinearZeros and Conv2dZeros modules, an additional scale term torch.exp(self.logs * self.logscale_factor) is multiplied in.
I found the same thing in the OpenAI TensorFlow implementation, but I am not sure why.
Thank you
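
For reference, a minimal sketch of that pattern as I read it from the two implementations (the class name and the logscale_factor value of 3.0 are assumptions): the layer is zero-initialised, and the output is additionally multiplied by exp(logs * logscale_factor), where logs is a learnable per-channel parameter; one concrete effect of the constant factor is that the gradient with respect to logs is multiplied by it.

import torch
import torch.nn as nn

class Conv2dZerosSketch(nn.Module):
    """Zero-initialised conv with a learnable, amplified per-channel log-scale."""
    def __init__(self, in_ch, out_ch, logscale_factor=3.0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        nn.init.zeros_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)
        self.logs = nn.Parameter(torch.zeros(out_ch, 1, 1))
        self.logscale_factor = logscale_factor

    def forward(self, x):
        # output starts at exactly zero; scaling logs by logscale_factor
        # amplifies its gradient so the per-channel scale can adapt quickly
        return self.conv(x) * torch.exp(self.logs * self.logscale_factor)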

Recommendations for computations stability

I've performed some tests on the reversible modules; I think the following recommendations should be helpful:

  • torch.inverse is not sufficiently accurate when using float; calling it with a double argument,
    torch.inverse(x.double()).float(), is preferred and significantly reduces the reconstruction error
  • there is a torch.slogdet method, which is more accurate than the manual log + abs + det currently used
    (see the sketch below):
    dlogdet = torch.log(torch.abs(torch.det(self.weight))) * pixels
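
A minimal, standalone sketch of the slogdet suggestion (the weight shape and pixel count are hypothetical):

import torch

weight = torch.randn(3, 3)      # hypothetical 1x1-conv kernel of shape (C, C)
pixels = 64 * 64                # hypothetical number of spatial positions

# manual log + abs + det, as in the current code
dlogdet_manual = torch.log(torch.abs(torch.det(weight))) * pixels

# torch.slogdet returns (sign, log|det|) and is numerically more stable
sign, logabsdet = torch.slogdet(weight)
dlogdet_slogdet = logabsdet * pixels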

Thanks for porting code to torch :)

Inquiry about actnorm layer implementation

Hi

Thank you for publishing your code!

I have been doing research using the Glow structure, and it has really helped me.

I have a question about the actnorm layer.

On this line, bias is the mean of the input X multiplied by -1.0, which is then used for calculating vars and logs.

bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0

To my knowledge, we should multiply by -1.0 again to get the mean vector, but the code copies the bias data to the learnable parameter without multiplying by -1.0 again.

self.bias.data.copy_(bias.data)

I would like to ask whether I am correct or not.
Thank you once again and look forward to your response!
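
For reference, a minimal sketch of the data-dependent initialisation as I understand it (my assumption: the actnorm forward pass computes (x + bias) * exp(logs), as in the official implementation, which is why the bias is stored as the negative mean):

import torch

x = torch.randn(8, 4, 16, 16)                        # hypothetical batch (N, C, H, W)
bias = -x.mean(dim=[0, 2, 3], keepdim=True)          # negative per-channel mean
vars_ = ((x + bias) ** 2).mean(dim=[0, 2, 3], keepdim=True)
logs = torch.log(1.0 / (torch.sqrt(vars_) + 1e-6))   # scale so the output has unit variance

y = (x + bias) * torch.exp(logs)
print(y.mean().item(), y.var().item())               # roughly 0 and 1 right after init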
