
s-vae-pytorch's People

Contributors

jmcorgan, nicola-decao, oskopek, tom-pelsmaeker, trdavidson, wcarvalho


s-vae-pytorch's Issues

Differentiability of the distribution objects.

Hello,

The torch.distributions objects are not differentiable, so when the reparameterization goes through torch.distributions the gradients do not propagate back to the encoder.

As a result, the model is not training as expected.

I also checked the backpropagation, and the KL loss does not seem to have gradients.

Can you please confirm this? Is the model training?

Best regards,
SB
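
For reference, a minimal check of whether gradients actually flow through the sampling step (a sketch using a standard Normal; the variable names are illustrative, and the key point is that rsample() is reparameterized while sample() is not):

import torch

# A parameter feeding the distribution, standing in for an encoder output.
mu = torch.zeros(3, requires_grad=True)
q = torch.distributions.Normal(mu, torch.ones(3))

z = q.rsample()            # reparameterized draw; sample() would cut the graph
loss = (z ** 2).sum()
loss.backward()

print(mu.grad)             # non-None here means gradients reach the encoder parameter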

Question about loss computation and gradients in the example.

Hi, I have a question about the loss computation in the MNIST example.

According to your paper, the gradient should have two terms, one named g_rep and the other named g_cor. The g_cor term is computed with the log-derivative trick, like the REINFORCE/score-function estimator. But at L151 of the example, I think only the g_rep term is considered. Did I miss something? Shouldn't L151 be something like
loss = loss_recon + loss_recon.detach() * log_g_div_r(eps, ) + loss_KL
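
For context, the usual autograd pattern for such a correction is a surrogate term whose weight is detached, so that backward() multiplies the weight by the gradient of the log-probability factor. A toy, self-contained sketch of that pattern (the names are illustrative stand-ins, not the repository's API):

import torch

theta = torch.tensor(0.5, requires_grad=True)
loss_recon = (theta - 1.0) ** 2       # stands in for the reconstruction term
log_prob_term = torch.log(theta)      # stands in for the log(g/r) factor
loss_KL = theta ** 2                  # stands in for the KL term

# Detaching the weight means the second term contributes
# loss_recon * d(log_prob_term)/d(theta) to the gradient, a
# score-function-style correction, without double-counting the
# reparameterized gradient from the first term.
loss = loss_recon + loss_recon.detach() * log_prob_term + loss_KL
loss.backward()
print(theta.grad)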

Constant variance in VMF

Hello,

There is a problem with the scale neuron in the vMF distribution. I don't know whether it is caused by my data, but with any type of activation function the variance neuron goes below 0, so it gets clamped to 1 and the entropy of the distribution becomes constant.
The same happens for the Power Spherical distribution, which you pointed out in the other thread.
I was wondering whether this is a bug or whether it also comes from the numerical instabilities in the distribution.
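
For what it's worth, a common way to keep a concentration/scale head strictly positive, so it never has to be clamped to a constant, is a softplus with a small offset. A minimal sketch under that assumption (the layer names are made up, and this is not necessarily how the example code defines its head):

import torch
import torch.nn.functional as F

features = torch.randn(8, 64)          # toy encoder features
scale_head = torch.nn.Linear(64, 1)    # hypothetical output layer for kappa

# softplus is positive for any input, and the +1 keeps the concentration
# away from zero, so the entropy stays a function of the input rather
# than a clamped constant.
kappa = F.softplus(scale_head(features)) + 1
print(kappa.min().item())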

g_cor term computation

Hi,
I have read your paper and it's very interesting.
Now I'm trying to understand your MNIST example implementation.

I see that the program uses kl_divergence and BCEWithLogitsLoss in the loss function.
So I think it computes the gradient of the KL divergence and of the g_rep term through backpropagation, but not of the g_cor term.

Why isn't a (log p(x|z) * log g/r) term added to the loss function to account for g_cor?
Or could you tell me where the g_cor term is computed?

Endless loop problem

In VonMisesFisher.py

def __while_loop(self, b, a, d, shape):
    # Tile the proposal constants to the sampling shape.
    b, a, d = [e.repeat(*shape, *([1] * len(self.scale.shape))) for e in (b, a, d)]
    w, e, bool_mask = torch.zeros_like(b).to(self.device), torch.zeros_like(b).to(self.device), (torch.ones_like(b) == 1).to(self.device)

    shape = shape + torch.Size(self.scale.shape)

    # Rejection-sampling loop: keep drawing proposals until every entry has been accepted.
    while bool_mask.sum() != 0:
        e_ = torch.distributions.Beta((self.__m - 1) / 2, (self.__m - 1) / 2).sample(shape[:-1]).reshape(shape).to(self.device)
        u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device)

        w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
        log_t = (2 * a * b) - (1 - (1 - b) * e_)

        accept = ((self.__m - 1) * log_t - log_t.exp() + d) > torch.log(u)
        reject = 1 - accept

        # Write accepted proposals and clear their entries from the pending mask.
        w[bool_mask * accept] = w_[bool_mask * accept]
        e[bool_mask * accept] = e_[bool_mask * accept]

        bool_mask[bool_mask * accept] = reject[bool_mask * accept]

The condition while bool_mask.sum() != 0: never becomes false, so the sampler ends up in an endless loop when z_dim > 2.
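
A minimal reproduction sketch, assuming the class is importable the way the MNIST example does it (the import path and constructor arguments are taken from that example and may need adjusting):

import torch
import torch.nn.functional as F
from hyperspherical_vae.distributions import VonMisesFisher

z_dim = 4                                          # anything larger than 2 per this report
loc = F.normalize(torch.randn(1, z_dim), dim=-1)   # mean direction on the sphere
scale = torch.ones(1, 1) * 10.0                    # concentration kappa

q = VonMisesFisher(loc, scale)
z = q.rsample()                                    # reportedly never returns (stuck in the while loop)
print(z)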

CPU bound

The current implementation of the Bessel function runs on the CPU. Do you have any workaround?
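
One workaround pattern, sketched here rather than taken from the repository, is to evaluate the Bessel term on the CPU with SciPy and move the result back to the working device, accepting the transfer cost and the fact that the value is detached from the autograd graph:

import torch
from scipy.special import ive   # exponentially scaled modified Bessel function I_v

def log_bessel_iv(nu, kappa):
    # log I_nu(kappa) for kappa > 0, computed on CPU and returned on kappa's device.
    # Note: .detach() means no gradient flows through this term; this only
    # illustrates where a CPU/GPU boundary could go.
    k_cpu = kappa.detach().cpu()
    out = torch.from_numpy(ive(nu, k_cpu.numpy())).log() + k_cpu
    return out.to(kappa.device)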

The code of link prediction on graphs

Hi, thanks for your code. I think your research is significant. I am interested in how this work performs on link prediction on graphs, and I would appreciate it if you could provide the code for those experiments.

GPU support

Hi,
I ran your vMF code on a GPU but encountered the following problem:

hyperspherical_vae/distributions/von_mises_fisher.py", line 132, in _kl_vmf_uniform
RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor

Does this code support GPU computation? Recent generative models rely heavily on deep networks.

Thanks.
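
For context, this kind of error usually comes from a tensor created inside the computation without an explicit device. A hedged illustration of the usual fix pattern, not the repository's actual line:

import torch

def kl_term_sketch(kappa):
    # Wrong: torch.tensor([1.0]) would live on the CPU even if kappa is on the GPU,
    # giving "expected type torch.cuda.FloatTensor but got torch.FloatTensor".
    # Right: inherit device and dtype from an existing tensor.
    const = torch.tensor([1.0], device=kappa.device, dtype=kappa.dtype)
    return kappa + const

x = torch.randn(4, 1)   # move to .cuda() to exercise the GPU path
print(kl_term_sketch(x))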

NaN value for loss

s-vae-pytorch-master/examples/mnist.py
I modified the training loop as follows:

def train(model, optimizer):
    for epoch in range(100):
        for i, (x_mb, y_mb) in enumerate(train_loader):

Hyper-spherical VAE

Epoch: 0 Loss: 169.04999
Epoch: 1 Loss: 157.91183
Epoch: 2 Loss: 149.1766
Epoch: 3 Loss: nan
Epoch: 4 Loss: nan
Epoch: 5 Loss: nan
Epoch: 6 Loss: nan
Epoch: 7 Loss: nan
Epoch: 8 Loss: nan
Epoch: 9 Loss: nan
Epoch: 10 Loss: nan
Epoch: 11 Loss: nan
Epoch: 12 Loss: nan
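
A generic first step for locating where the NaN first appears, not specific to this repository:

import torch

# Make the backward pass raise at the first op that produces NaN/Inf,
# instead of silently propagating it into the loss.
torch.autograd.set_detect_anomaly(True)

# Inside the training loop one can additionally skip bad batches and
# clip gradients, e.g. (illustrative only):
#   if torch.isnan(loss):
#       continue
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
#   optimizer.step()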

Numerical instability of the samples

Hello,
When sampling from the vMF, the samples sometimes become extremely large, very far from the mean, even a million times larger than the other samples from the distribution. I was wondering how one can avoid that.
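
Since vMF samples should lie on the unit hypersphere, a quick sanity check is to look at their norms; values far from 1 point at the sampler rather than the data. A sketch, again assuming the class from the MNIST example:

import torch
import torch.nn.functional as F
from hyperspherical_vae.distributions import VonMisesFisher

loc = F.normalize(torch.randn(16, 8), dim=-1)
scale = torch.ones(16, 1) * 5.0

z = VonMisesFisher(loc, scale).rsample()
norms = z.norm(dim=-1)
print(norms.min().item(), norms.max().item())   # both should be close to 1.0

# Re-projecting onto the sphere hides the symptom rather than fixing it,
# but can serve as a stopgap:
z = F.normalize(z, dim=-1)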
