
Comments (7)

f-dangel commented on August 16, 2024

Hi, thanks for your question.

Your approach of using the converter functionality seems correct, and BackPACK is capable of computing individual gradients for LSTM layers. Do you have a small code snippet that reproduces the problem? That would be extremely helpful.
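
In case it helps, here is a minimal sketch of the intended pattern, with made-up layer sizes and data (TinyLSTMNet is just an illustration, not from your code): extend the model with use_converter=True, extend the loss, backpropagate inside the backpack context, and read grad_batch from the parameters.

import torch
from torch import nn
from backpack import backpack, extend
from backpack.extensions import BatchGrad

class TinyLSTMNet(nn.Module):
    # toy model for illustration: LSTM followed by a linear head
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        h, _ = self.lstm(x)  # h: [batch, seq, hidden]
        return self.head(h)  # [batch, seq, 1]

# use_converter=True lets BackPACK rewrite the module so RNN/LSTM layers are supported
model = extend(TinyLSTMNet(), use_converter=True)
loss_fn = extend(nn.MSELoss())

x = torch.rand(3, 5, 4)  # batch of 3 sequences, length 5, 4 features each
y = torch.rand(3, 5, 1)

loss = loss_fn(model(x), y)
with backpack(BatchGrad()):
    loss.backward()

for name, p in model.named_parameters():
    print(name, p.grad_batch.shape)  # leading dimension is the batch size (3)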

Felix


LB-bulb commented on August 16, 2024

Hi, thanks very much for your reply.
I made a small code snippet; it may be a little long, but it reproduces the same problem.

from __future__ import print_function
import torch
from torch import nn, optim, autograd
from backpack import backpack, extend
from backpack.extensions import BatchGrad

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.ip_emb = torch.nn.Linear(1, 8)
        self.leaky_relu = torch.nn.LeakyReLU(0.1)
        self.enc_lstm = torch.nn.LSTM(8, 64, 1)
        self.decode = extend(decode(), use_converter=True)  # the part I am interested in

    def lossfun(self):
        return nn.MSELoss(reduction='sum')

    def forward(self, hist):
        _, (enc, _) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))
        out = self.decode(enc)
        return out, enc

class decode(nn.Module):
    def __init__(self):
        super(decode, self).__init__()
        self.dec_lstm = torch.nn.LSTM(64, 8, batch_first=True)
        self.op = torch.nn.Linear(8, 1)

    def forward(self, enc):
        h_dec, _ = self.dec_lstm(enc)
        fut_pred = self.op(h_dec)
        return fut_pred

net = Net()
mse_extended = extend(net.lossfun())
x = torch.zeros(15, 5, 1)
y = torch.zeros(1, 5, 1)
out, enc = net(x)

pred = net.decode(enc)
loss = mse_extended(pred, y)

with backpack(BatchGrad()):
    loss.backward(
        inputs=list(net.decode.parameters()), retain_graph=True, create_graph=True
    )
count = 0
for name, weights in net.decode.named_parameters():
    count = count + 1
    print(count)
    print('name', name)
    print('weights', weights.shape)
    print('weights', type(weights))
    print(weights.requires_grad)
    print(weights.grad_batch.shape)

By the way, the torch version is 1.9.1. I'm not sure if this has an effect.


f-dangel commented on August 16, 2024

Hi,

I was able to reproduce your problem with the snippet and looked at the code BackPACK executes by running with backpack(..., debug=True).

This revealed that BackPACK does not execute the BatchGrad extension on your decoder's LSTM layer (relevant output only):

[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7f352eeceb50> on Linear(in_features=8, out_features=1, bias=True)

The reason seems to be that, by setting inputs=list(net.decode.parameters()) in the backward call, PyTorch's autodiff won't fire BackPACK's backward hook on the LSTM but stops before reaching it. (I am not familiar with the logic of backward hook execution when inputs=... is specified.) If I comment out the inputs=..., BackPACK executes on the LSTM (relevant output only, see the last row):

[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on MSELoss()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on Linear(in_features=8, out_features=1, bias=True)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fac1398ed50> on LSTM(64, 8, batch_first=True)
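
For reference, a sketch of the workaround on your snippet under that assumption (I only checked that the hooks fire): call backward without inputs=..., then read grad_batch from the decoder's parameters only.

# workaround sketch: no inputs=..., so autograd traverses the whole graph
# and BackPACK's hooks also fire on the decoder's LSTM
with backpack(BatchGrad()):
    loss.backward(retain_graph=True, create_graph=True)

# individual gradients should now be available on the decoder's parameters
for name, weights in net.decode.named_parameters():
    print(name, weights.grad_batch.shape)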

Hope this helps.
Felix


preminstrel commented on August 16, 2024

@LB-bulb @f-dangel Hello, I have the same problem when reproducing Fishr. I solved the problem by reinstalling backpack-for-pytorch==1.3.0 (the older version).


Chelsea-abab commented on August 16, 2024

@preminstrel Hi, I am also reproducing Fishr. I use torch==1.13.1 and backpack-for-pytorch==1.3.0, and an error like
ImportError: cannot import name '_grad_input_padding' from 'torch.nn.grad' occurred. This is because _grad_input_padding was removed in newer versions of torch. May I ask how you fixed it?


preminstrel commented on August 16, 2024

@Chelsea-abab I used torch == 1.12.1. Maybe you can try this torch version?
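
If it helps, here is a quick, unofficial check (just an illustration, not a fix) of whether the installed torch still ships the private symbol that backpack-for-pytorch==1.3.0 imports:

# sanity check: backpack-for-pytorch==1.3.0 imports torch.nn.grad._grad_input_padding,
# which newer torch releases removed
try:
    from torch.nn.grad import _grad_input_padding  # noqa: F401
    print("this torch still exposes _grad_input_padding")
except ImportError:
    print("this torch removed _grad_input_padding; try e.g. torch==1.12.1")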


Chelsea-abab commented on August 16, 2024

@preminstrel Thanks! It does work!

