Comments (4)
Hello @zhan-xu, sorry for my late reply. I fixed backpropagation through the vSymEig function by:
- Removing the in-place operation on the norm
- Adding a small epsilon inside the square root of the norm computation, to avoid undefined gradients at 0.
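A minimal sketch of both changes. It assumes the eigenvector components are held as three same-shaped tensors `u0`, `u1`, `u2` and that the norm is computed with `torch.sqrt`. The helper name `safe_normalize` and the `eps` default are hypothetical, not the library's actual code:

```python
import torch


def safe_normalize(u0, u1, u2, eps=1e-10):
    """Hypothetical helper: normalize a field of 3-vectors stored component-wise.

    Assumes u0, u1, u2 share the same shape (e.g. B x D x H x W).
    """
    # eps sits inside the square root, so sqrt never receives an exact 0 and its
    # backward (grad / (2 * sqrt(.))) stays finite for zero-length vectors.
    norm = torch.sqrt(u0 * u0 + u1 * u1 + u2 * u2 + eps)
    # Out-of-place divisions only: the tensor sqrt saved for backward is never
    # modified, so autograd's saved-tensor version check cannot fail.
    return u0 / norm, u1 / norm, u2 / norm


# Two vectors stored component-wise: (3, 4, 0) and the zero vector (0, 0, 0).
u0 = torch.tensor([3.0, 0.0], requires_grad=True)
u1 = torch.tensor([4.0, 0.0], requires_grad=True)
u2 = torch.tensor([0.0, 0.0], requires_grad=True)
v0, v1, v2 = safe_normalize(u0, u1, u2)
(v0.sum() + v1.sum() + v2.sum()).backward()
print(v0, v1)            # first vector becomes approximately (0.6, 0.8, 0.0)
print(u0.grad, u1.grad)  # finite, including at the zero vector
```

The epsilon slightly shrinks very short vectors, which is usually an acceptable trade for finite gradients.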
I ran an optimization test with an Adam optimizer and an L1 loss on both the eigenvalues and the eigenvectors, and it looks good now.
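A test of this kind can be sketched as follows. This is not the test that was actually run: `torch.linalg.eigh` stands in for `vSymEig` (so matrices use PyTorch's batch-first `..., 3, 3` layout rather than the volumetric one), and the batch size, learning rate and step count are arbitrary. Because eigenvectors are only defined up to sign, the eigenvector term is not guaranteed to reach zero; the point is that the loss and gradients stay finite throughout:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def symmetric(a):
    # Symmetrize the last two dims so eigh receives a valid symmetric input.
    return 0.5 * (a + a.transpose(-1, -2))


# Target eigen-decomposition of a fixed random batch of 3x3 symmetric matrices.
with torch.no_grad():
    gt_vals, gt_vecs = torch.linalg.eigh(symmetric(torch.rand(4, 3, 3)))

# Learnable matrices, randomly initialized.
param = torch.rand(4, 3, 3, requires_grad=True)
optimizer = torch.optim.Adam([param], lr=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()
    vals, vecs = torch.linalg.eigh(symmetric(param))
    # L1 loss on both outputs: eigenvalues and eigenvectors.
    loss = F.l1_loss(vals, gt_vals) + F.l1_loss(vecs, gt_vecs)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```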
from torch-vectorized.
Maybe the problem is not the shape. Here is some more information; I tried the following code:
```python
import torch
from torchvectorized.utils import sym
from torchvectorized.vlinalg import vSymEig

# Random batch of volumetric 3x3 symmetric matrices of size 16x9x32x32x32
input = sym(torch.rand(16, 9, 32, 32, 32))
input.requires_grad = True
# Outputs eig_vals of size 16x3x32x32x32 and eig_vecs of size 16x3x3x32x32x32
eig_vals, eig_vecs = vSymEig(input, eigen_vectors=True)
gt_vals = torch.randn((16, 3, 32, 32, 32))
gt_vecs = torch.randn((16, 3, 3, 32, 32, 32))
# loss = torch.nn.functional.mse_loss(eig_vals, gt_vals)  # this line runs without error
loss = torch.nn.functional.mse_loss(eig_vecs, gt_vecs)  # this line raises an error
loss.backward()
```
The problem seems to happen only when we add supervision on eig_vecs.
The problem may be related to https://github.com/banctilrobitaille/torch-vectorized/blob/master/torchvectorized/vlinalg.py#L55
When I comment out line 56, it works. By the way, why do you replace 0 with 1 here? Shouldn't it be replaced with a small number like 1e-10 instead?
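A sketch of why an in-place replacement can break backpropagation. It assumes line 56 does something like `norm[norm == 0] = 1` on a norm produced by `torch.sqrt` (the exact library code is not reproduced here). `torch.sqrt` saves its output for the backward pass, so writing into that output in place makes autograd raise a `RuntimeError` once gradients reach it. If the eigenvalues do not depend on that norm, this may also explain why only the eigenvector loss fails. The helper names and shapes are hypothetical:

```python
import torch


def normalize_inplace(v):
    # Hypothetical reconstruction of the "replace 0 with 1" pattern, not the library's code.
    norm = torch.sqrt((v * v).sum(dim=0, keepdim=True))
    norm[norm == 0] = 1  # in-place write into sqrt's output, which sqrt saved for backward
    return v / norm


def normalize_eps(v, eps=1e-10):
    # Out-of-place alternative: eps inside the square root, no in-place writes.
    norm = torch.sqrt((v * v).sum(dim=0, keepdim=True) + eps)
    return v / norm


v = torch.rand(3, 4)
v[:, 0] = 0  # make one column (one 3-vector) exactly zero
v.requires_grad_(True)

inplace_failed = False
try:
    normalize_inplace(v).sum().backward()
except RuntimeError as err:
    # Autograd detects that a tensor saved for backward was modified afterwards.
    inplace_failed = True
    print("in-place version:", err)

v.grad = None
normalize_eps(v).sum().backward()
print("eps version grads finite:", torch.isfinite(v.grad).all().item())
```

Replacing the zeros out of place (for example with `torch.where`) would avoid the `RuntimeError`, but for an exactly-zero vector the gradient flowing back through `torch.sqrt` would still be 0/0 = NaN, which is the case an epsilon inside the square root handles.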
As a workaround, I changed line 56 to:
```python
u0 = u0 / (norm + 1e-10)
u1 = u1 / (norm + 1e-10)
u2 = u2 / (norm + 1e-10)
```
It seems to work so far.
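This workaround removes the in-place write and the division by zero in the forward pass. One caveat, assuming the norm itself comes from `torch.sqrt`: a vector that is exactly zero still gets a NaN gradient, because the backward of `torch.sqrt` evaluates 0/0 at that point. Random test inputs almost never hit exact zeros, which would be consistent with the workaround appearing to work. Moving the epsilon inside the square root, as in the fix described at the top of this thread, avoids the NaN. The two helpers below are hypothetical and the shapes arbitrary:

```python
import torch


def normalize_eps_outside(v, eps=1e-10):
    # Workaround form: eps added to the norm after the square root.
    norm = torch.sqrt((v * v).sum(dim=0, keepdim=True))
    return v / (norm + eps)


def normalize_eps_inside(v, eps=1e-10):
    # Fix form: eps added inside the square root.
    norm = torch.sqrt((v * v).sum(dim=0, keepdim=True) + eps)
    return v / norm


# A single exactly-zero 3-vector, the edge case that random inputs rarely produce.
for fn in (normalize_eps_outside, normalize_eps_inside):
    v = torch.zeros(3, 1, requires_grad=True)
    fn(v).sum().backward()
    print(fn.__name__, v.grad.flatten().tolist())
```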