Comments (7)
Hi, thanks for your question.
Your approach of using the converter functionality seems correct, and BackPACK is capable of computing individual gradients for LSTM layers. Do you have a small code snippet that reproduces the problem? That would be extremely helpful.
Felix
Hi, thanks very much for your reply.
I made a small code snippet. It may be a little long, but it reproduces the same problem.
from __future__ import print_function
import torch
from torch import nn, optim, autograd
from backpack import backpack, extend
from backpack.extensions import BatchGrad


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ip_emb = torch.nn.Linear(1, 8)
        self.leaky_relu = torch.nn.LeakyReLU(0.1)
        self.enc_lstm = torch.nn.LSTM(8, 64, 1)
        # the part I am interested in
        self.decode = extend(decode(), use_converter=True)

    def lossfun(self):
        return nn.MSELoss(reduction='sum')

    def forward(self, hist):
        _, (enc, _) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))
        out = self.decode(enc)
        return out, enc


class decode(nn.Module):
    def __init__(self):
        super(decode, self).__init__()
        self.dec_lstm = torch.nn.LSTM(64, 8, batch_first=True)
        self.op = torch.nn.Linear(8, 1)

    def forward(self, enc):
        h_dec, _ = self.dec_lstm(enc)
        fut_pred = self.op(h_dec)
        return fut_pred


net = Net()
mse_extended = extend(net.lossfun())
x = torch.zeros(15, 5, 1)
y = torch.zeros(1, 5, 1)
out, enc = net(x)
pred = net.decode(enc)
loss = mse_extended(pred, y)
with backpack(BatchGrad()):
    loss.backward(
        inputs=list(net.decode.parameters()), retain_graph=True, create_graph=True
    )
count = 0
for name, weights in net.decode.named_parameters():
    count = count + 1
    print(count)
    print('name', name)
    print('weights', weights.shape)
    print('weights', type(weights))
    print(weights.requires_grad)
    print(weights.grad_batch.shape)
By the way, the torch version is 1.9.1. I'm not sure if this has an effect.
Hi,
I was able to reproduce your problem with the snippet and looked at the code BackPACK executes by running with backpack(..., debug=True).
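In your snippet, that corresponds to adding the debug flag to the backpack context manager (only debug=True is new; the backward call is unchanged):

# Print which BackPACK extensions run on which modules during backward.
with backpack(BatchGrad(), debug=True):
    loss.backward(
        inputs=list(net.decode.parameters()), retain_graph=True, create_graph=True
    )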
This revealed that BackPACK does not execute the BatchGrad extension on your decoder's LSTM layer (relevant output only):
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on Linear(in_features=8, out_features=1, bias=True)
The reason seems to be that by setting inputs=list(net.decode.parameters()) in the backward call, PyTorch's autodiff does not fire BackPACK's backward hook on the LSTM but stops before it. (I am not familiar with the details of backward-hook execution when inputs=... is specified.) If I comment out the inputs=... argument, BackPACK executes on the LSTM (relevant output only, see last row):
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on Linear(in_features=8, out_features=1, bias=True)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on LSTM(64, 8, batch_first=True)
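So a simple workaround (a minimal sketch based on your snippet, not a fix for the inputs=... behavior) is to drop the inputs argument; the individual gradients then also become available on the decoder's LSTM:

# Backward pass without restricting autodiff to the decoder's parameters;
# BackPACK's hooks now also fire on the LSTM.
with backpack(BatchGrad()):
    loss.backward(retain_graph=True, create_graph=True)

for name, weights in net.decode.named_parameters():
    print(name, weights.grad_batch.shape)

Note that this computes individual gradients for all parameters of net, not only the decoder's.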
Hope this helps.
Felix
@LB-bulb @f-dangel Hello, I have the same problem when reproducing Fishr. I solved it by reinstalling backpack-for-pytorch==1.3.0 (the older version).
@preminstrel Hi, I am also reproducing Fishr. I use torch==1.13.1 and backpack-for-pytorch==1.3.0, and an error like ImportError: cannot import name '_grad_input_padding' from 'torch.nn.grad' occurred. This is because _grad_input_padding was removed in later versions of torch. May I ask how you fixed it?
@Chelsea-abab I used torch==1.12.1. Maybe you can try this torch version?
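For reference, that combination can be pinned with something like the following (assuming a pip-based environment):

pip install torch==1.12.1 backpack-for-pytorch==1.3.0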
@preminstrel Thanks! It does work!