nicola-decao / s-vae-pytorch
PyTorch implementation of Hyperspherical Variational Auto-Encoders
Home Page: http://arxiv.org/abs/1804.00891
License: MIT License
Hello,
The torch.distributions samples do not appear to be differentiable, so gradients do not propagate to the encoder through the reparametrization that uses torch.distributions.
As a result, the model is not training as expected.
I also checked the backpropagation, and the KL loss does not seem to have gradients.
Can you please confirm this? Is the model training?
Best regards,
SB
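For context, a minimal way to check whether gradients flow through a torch.distributions sample (my own sketch, not code from this repository): rsample() draws a reparameterized sample that carries gradients, while sample() is detached from the autograd graph.

import torch

loc = torch.randn(3, requires_grad=True)
dist = torch.distributions.Normal(loc, torch.ones(3))

z = dist.rsample()      # pathwise / reparameterized sample
z.sum().backward()
print(loc.grad)         # a tensor: gradients reach the parameters

loc.grad = None
z = dist.sample()       # plain sampling, detached from the autograd graph
print(z.requires_grad)  # False: nothing to backpropagate through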
Hi, I have a question about the loss computation in the MNIST example.
According to your paper, the gradient should consist of two terms, one named g_rep and the other named g_cor. The g_cor term is computed with the log-derivative trick, which is like the REINFORCE/score-function estimator. But at L151, I think only the g_rep term is considered. Did I miss something? Shouldn't L151 be something like:

loss = loss_recon + loss_recon.detach() * log_g_div_r(eps, ) + loss_KL
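For reference, the decomposition being discussed, in my own paraphrase of the acceptance-rejection reparameterization gradient (Naesseth et al.) that the paper builds on; the notation here may differ from the paper's:

$$\nabla_\theta \, \mathbb{E}_{q(z;\theta)}\big[f(z)\big] = \underbrace{\mathbb{E}\big[\nabla_\theta f(h(\epsilon,\theta))\big]}_{g_{\mathrm{rep}}} + \underbrace{\mathbb{E}\Big[f(h(\epsilon,\theta)) \, \nabla_\theta \log \tfrac{g(\epsilon;\theta)}{r(\epsilon;\theta)}\Big]}_{g_{\mathrm{cor}}}$$

Here z = h(ε, θ) is the accepted-sample transformation: g_rep flows through the sample by ordinary backpropagation, while g_cor is the score-function correction the issue refers to as log g/r.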
Hello,
There is a problem with the scale neuron of the vMF distribution. I don't know whether it's because of my data, but with any type of activation function the variance neuron goes below 0 and is clamped to 1, so the entropy of the distribution becomes constant.
This is also the case for the Power Spherical distribution, as you pointed out in the other thread.
I was wondering whether this is a bug or whether it comes from the numerical instabilities in the distribution.
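One common way to keep the concentration strictly positive is a softplus head instead of a clamp (a sketch of the usual workaround; if I read the repository's MNIST example correctly, it does something similar for the scale output):

import torch
import torch.nn.functional as F

raw = torch.randn(8, 1)        # unconstrained output of the scale head
kappa = F.softplus(raw) + 1    # strictly positive concentration, kappa > 1
print(kappa.min().item() > 0)  # True: no clamping, entropy stays input-dependent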
Hi,
I have read your paper and it's very interesting.
Now I'm trying to understand your MNIST example implementation.
I see that the program uses kl_divergence and BCEWithLogitsLoss in the loss function.
So I think the program computes the gradient of the KL divergence and the g_rep term through backpropagation, but not the g_cor term.
Why don't you put the (log p(x|z) * log g/r) term into the loss function to compute g_cor?
Or can you tell me where the g_cor term is computed?
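For what it's worth, a hedged sketch of how such a surrogate objective is usually written (my own illustration; log_correction stands in for the hypothetical log g/r term and is not a function defined in the example):

import torch

def surrogate_loss(loss_recon, loss_KL, log_correction):
    # g_rep flows through the reparameterized sample inside loss_recon and
    # loss_KL; g_cor comes from the score-function term, with the
    # reconstruction loss detached so only the log-ratio gets gradients there.
    return loss_recon + loss_KL + loss_recon.detach() * log_correction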
In VonMisesFisher.py:

def __while_loop(self, b, a, d, shape):
    # Rejection sampling (Wood, 1994) for the marginal w of the vMF distribution.
    b, a, d = [e.repeat(*shape, *([1] * len(self.scale.shape))) for e in (b, a, d)]
    w, e, bool_mask = (
        torch.zeros_like(b).to(self.device),
        torch.zeros_like(b).to(self.device),
        (torch.ones_like(b) == 1).to(self.device),
    )
    shape = shape + torch.Size(self.scale.shape)

    while bool_mask.sum() != 0:
        e_ = torch.distributions.Beta((self.__m - 1) / 2, (self.__m - 1) / 2).sample(shape[:-1]).reshape(shape).to(self.device)
        u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device)

        w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_)
        log_t = ((2 * a * b) / (1 - (1 - b) * e_)).log()

        accept = ((self.__m - 1) * log_t - log_t.exp() + d) > torch.log(u)
        reject = 1 - accept

        # Overwrite only the still-pending entries whose proposals were accepted.
        w[bool_mask * accept] = w_[bool_mask * accept]
        e[bool_mask * accept] = e_[bool_mask * accept]
        bool_mask[bool_mask * accept] = reject[bool_mask * accept]
The while bool_mask.sum() != 0: loop never terminates when z_dim > 2, so sampling hangs in an endless loop.
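One generic safeguard (my suggestion, not the repository's fix) is to cap the number of rejection rounds so the sampler fails loudly instead of spinning forever; a minimal self-contained sketch of the pattern:

import torch

def rejection_loop(propose, accept_test, shape, max_iters=1000):
    # Capped rejection-sampling loop: redraw proposals for the still-pending
    # entries and raise instead of hanging if acceptance stalls.
    out = torch.zeros(shape)
    pending = torch.ones(shape, dtype=torch.bool)
    for _ in range(max_iters):
        if not pending.any():
            return out
        cand = propose(shape)
        ok = accept_test(cand) & pending
        out[ok] = cand[ok]
        pending &= ~ok
    raise RuntimeError(f"{int(pending.sum())} samples still pending after {max_iters} iterations")

# Usage (toy example): standard normal truncated to w > 0.
samples = rejection_loop(lambda s: torch.randn(s), lambda c: c > 0, (4, 3))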
The current implementation of the Bessel function is on CPU. Do you have any workaround?
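One workaround sketch (an assumption about the approach, built on scipy): wrap the CPU-only scipy.special.ive in a custom autograd Function that round-trips through CPU, so the surrounding model can still live on the GPU and gradients still flow.

import scipy.special
import torch

class Ive(torch.autograd.Function):
    # Exponentially scaled modified Bessel function of the first kind,
    # computed on CPU via scipy and moved back to the input's device.
    @staticmethod
    def forward(ctx, v, z):
        ctx.v = v
        ctx.save_for_backward(z)
        out = scipy.special.ive(v, z.detach().cpu().numpy())
        return torch.as_tensor(out, dtype=z.dtype, device=z.device)

    @staticmethod
    def backward(ctx, grad_output):
        (z,) = ctx.saved_tensors
        # d/dz ive(v, z) = ive(v - 1, z) - ive(v, z) * (v + z) / z
        return None, grad_output * (ive(ctx.v - 1, z) - ive(ctx.v, z) * (ctx.v + z) / z)

def ive(v, z):
    return Ive.apply(v, z)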
Hi, thanks for your code. I think your research is significant. I am interested in the performance of this work on link prediction on graphs. I would appreciate it if you could provide the code for that experiment.
Hi,
I ran your vMF code on a GPU, but encountered the following problem:
hyperspherical_vae/distributions/von_mises_fisher.py", line 132, in _kl_vmf_uniform
RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor
Does this code support GPU computation? Recent generative models rely heavily on deep networks.
Thanks.
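For context, this kind of error usually comes from a tensor created on the CPU meeting a CUDA tensor inside the KL computation; a minimal reproduction and the usual fix (my own sketch, not the repository's code):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
scale = torch.ones(4, 1, device=device)

bad = torch.zeros(4, 1)         # created on CPU by default
# scale + bad                   # on a GPU machine this raises the device-mismatch error
good = torch.zeros_like(scale)  # inherits scale's device and dtype
print((scale + good).device)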
Hi,
I see you have implemented the derivative of [equation image] w.r.t. [equation image] as [equation image], but the derivative (confirmed by Wolfram Alpha) is given by [equation image].
Is there a particular reason for the extra [image] term in the gradient, or have I missed something?
Thanks in advance,
Amanjit
s-vae-pytorch-master/examples/mnist.py
I modified the code to train for 100 epochs:

def train(model, optimizer):
    for epoch in range(100):
        for i, (x_mb, y_mb) in enumerate(train_loader):

The loss becomes NaN from epoch 3 onward:
Epoch: 0 Loss: 169.04999
Epoch: 1 Loss: 157.91183
Epoch: 2 Loss: 149.1766
Epoch: 3 Loss: nan
Epoch: 4 Loss: nan
Epoch: 5 Loss: nan
Epoch: 6 Loss: nan
Epoch: 7 Loss: nan
Epoch: 8 Loss: nan
Epoch: 9 Loss: nan
Epoch: 10 Loss: nan
Epoch: 11 Loss: nan
Epoch: 12 Loss: nan
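A defensive variant of the training step (my suggestion for debugging, not the repository's fix): surface non-finite losses immediately and clip gradients, which helps distinguish a numerical blow-up in the vMF terms from an optimizer divergence.

import torch

def training_step(model, optimizer, loss_fn, x_mb):
    optimizer.zero_grad()
    loss = loss_fn(model, x_mb)
    if not torch.isfinite(loss):  # stop at the first NaN/inf instead of training on it
        raise FloatingPointError(f"non-finite loss: {loss.item()}")
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    optimizer.step()
    return loss.item()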
Hello,
When sampling from the vMF, the samples sometimes get extremely large, very far from the mean, even a million times larger than the other samples from the distribution. I was wondering how one can avoid this.
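Since vMF samples should lie on the unit hypersphere, a quick sanity check can tell whether the sampler itself is misbehaving (a sketch assuming the package's import path and constructor signature):

import torch
from hyperspherical_vae.distributions import VonMisesFisher

mu = torch.nn.functional.normalize(torch.randn(16, 8), dim=-1)   # unit-norm mean directions
kappa = torch.ones(16, 1) * 10.0                                 # concentration
vmf = VonMisesFisher(mu, kappa)
z = vmf.rsample()
print(z.norm(dim=-1).min().item(), z.norm(dim=-1).max().item())  # both should be ~1.0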