
sgconv's People

Contributors

ctlllll


sgconv's Issues

checkpoint loading issue

Greetings,

I ran into some trouble loading an SGConv network from checkpoint. Particularly, I encountered the following problem:

size mismatch for kernel_norm: copying a param with shape torch.Size([2, 256, 1]) from checkpoint, the shape in current model is torch.Size([256, 1]).

A few things I have noticed:

  • The error can be reproduced by running the code below in a Jupyter notebook (.ipynb).
  • In the code below, the error goes away if I run a forward pass through the new layer before loading the checkpoint.

Any pointers on what causes this error and how to solve it would be very welcome. Thanks!


import torch
from gconv_standalone import GConv

# Build the layer and run one forward pass on the GPU.
layer = GConv(
    d_model=256,
    d_state=64,
    l_max=1_000_000,
    bidirectional=True,
    kernel_dim=32,
    n_scales=None,
    decay_min=2,
    decay_max=2,
)

x = torch.randn(1, 256, 1000)
x = x.cuda()
layer.cuda()
y, k = layer(x, return_kernel=True)

# Save the state dict after the forward pass.
path = './dummy_ckpt'
torch.save({
    'state_dict': layer.state_dict()
}, path)

# Fresh layer with the same hyperparameters, not yet forwarded.
shell = GConv(
    d_model=256,
    d_state=64,
    l_max=1_000_000,
    bidirectional=True,
    kernel_dim=32,
    n_scales=None,
    decay_min=2,
    decay_max=2,
).cuda()

# Loading here raises the size mismatch for kernel_norm.
ckpt = torch.load(path)
shell.load_state_dict(ckpt['state_dict'])
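
The symptom described above (a buffer whose shape differs before and after the first forward pass) can be illustrated with a toy module. This is only a sketch: `LazyNormLayer` is a hypothetical stand-in, not the real GConv, and it assumes the mismatch comes from `kernel_norm` being replaced with a differently shaped tensor on the first forward call. It runs on CPU.

```python
import torch
import torch.nn as nn


class LazyNormLayer(nn.Module):
    """Toy stand-in (NOT the real GConv): a buffer changes shape on first forward."""

    def __init__(self, d_model=4, channels=2, kernel_len=8):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(channels, d_model, kernel_len))
        # Placeholder registered with shape (d_model, 1).
        self.register_buffer("kernel_norm", torch.ones(d_model, 1))
        self.initialized = False

    def forward(self, x):
        if not self.initialized:
            # First call: the placeholder is replaced by a tensor of shape
            # (channels, d_model, 1), so the state dict changes shape.
            self.kernel_norm = self.kernel.norm(dim=-1, keepdim=True).detach()
            self.initialized = True
        kernel = self.kernel / self.kernel_norm
        return x * kernel.mean()


torch.manual_seed(0)
x = torch.randn(1, 4, 16)

trained = LazyNormLayer()
trained(x)                      # kernel_norm now has shape (2, 4, 1)
state = trained.state_dict()

fresh = LazyNormLayer()         # kernel_norm still has shape (4, 1)
try:
    fresh.load_state_dict(state)
    fresh_loaded = True
except RuntimeError:
    fresh_loaded = False        # size mismatch for kernel_norm

warmed = LazyNormLayer()
warmed(x)                       # one forward pass gives the buffer its final shape
warmed.load_state_dict(state)   # shapes now agree, so loading succeeds
```

Under that assumption, running one dummy forward pass before `load_state_dict` is a workaround rather than a fix; registering the buffer with its final shape in the constructor would avoid the mismatch altogether.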

The question about gconv.py

Hello. I read the code and have two questions.

  1. It seems that d_state is not used anywhere in the code. What is this parameter for?
  2. Why is the multiplier between 1 and 4 by default? I think it needs to be a value smaller than 1, like 1/2 in the paper.
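
One possible reading of the second question, offered as an unverified assumption about gconv.py: if the code scales sub-kernel `i` by `multiplier ** (n_scales - 1 - i)` and then normalizes the whole kernel, a multiplier of 2 gives the same relative weights as a decay of 1/2 applied as `0.5 ** i`. The helper below is hypothetical and only checks that arithmetic.

```python
def normalized_scale_weights(base, n_scales, descending_exponent):
    """Relative weight of each sub-kernel after normalizing the weights to sum to 1.

    descending_exponent=False -> weight_i = base ** i                (decay written as in the paper)
    descending_exponent=True  -> weight_i = base ** (n_scales-1-i)   (ASSUMED form of the code)
    """
    if descending_exponent:
        exponents = range(n_scales - 1, -1, -1)
    else:
        exponents = range(n_scales)
    raw = [base ** e for e in exponents]
    total = sum(raw)
    return [w / total for w in raw]


paper_style = normalized_scale_weights(0.5, 4, descending_exponent=False)
code_style = normalized_scale_weights(2.0, 4, descending_exponent=True)
max_gap = max(abs(a - b) for a, b in zip(paper_style, code_style))
print(paper_style, code_style, max_gap)
```

The equivalence only holds when the kernel is normalized afterwards; without normalization the two forms differ by a constant factor.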

Complex Tensors

Greetings,

May I kindly request an adaptation of the code where the convolution operation is done without resorting to fft? This is largely because torch DDP does not support complex tensors well (see pytorch/pytorch#80080), preventing the usage of the model in a distributed training scenario (related to a current issue #3).

Any updates, pointers, or possible workarounds would be greatly appreciated!
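
For reference, a long convolution can be computed without complex tensors by using a direct depthwise `conv1d`. The sketch below is not the repository's implementation; the function names and the tensor layout (`x` of shape `(batch, channels, length)`, `k` of shape `(channels, kernel_len)`) are assumptions made for illustration. It checks the direct version against an FFT version on a small input.

```python
import torch
import torch.nn.functional as F


def fft_causal_conv(x, k):
    """Causal convolution via FFT (uses complex intermediates)."""
    seq_len, kernel_len = x.shape[-1], k.shape[-1]
    n = seq_len + kernel_len  # zero-pad so the circular convolution does not wrap
    y = torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(k, n=n), n=n)
    return y[..., :seq_len]


def direct_causal_conv(x, k):
    """Same result with real tensors only, at O(length * kernel_len) cost."""
    channels, kernel_len = k.shape
    # conv1d computes a cross-correlation, so flip the kernel; left-pad for causality.
    weight = k.flip(-1).unsqueeze(1)          # (channels, 1, kernel_len)
    x_padded = F.pad(x, (kernel_len - 1, 0))
    return F.conv1d(x_padded, weight, groups=channels)


torch.manual_seed(0)
x = torch.randn(2, 3, 32)
k = torch.randn(3, 32)
y_fft = fft_causal_conv(x, k)
y_direct = direct_causal_conv(x, k)
```

The direct form scales quadratically when the kernel is as long as the sequence, so it is only practical for short sequences or truncated kernels.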

Update:

I managed to find the issue. It is not related to complex tensors, since all SGConv parameters are either torch.float32 or torch.int64. The cause is self.kernel_norm_initialized, which is registered as torch.bool. Although the NCCL backend supports torch.bool (pytorch/pytorch#41959), this flag still appears to trigger the failure. Changing its registered dtype from torch.bool to torch.float32 resolved the problem.
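
A minimal sketch of the change described in the update, using toy modules rather than the real SGConv code. Only the name `kernel_norm_initialized` comes from the issue; the classes and the `non_float_entries` helper are hypothetical. The helper lists state-dict entries that are not floating point, which is one way to spot such flags.

```python
import torch
import torch.nn as nn


class FlagAsBool(nn.Module):
    """Toy module with the flag registered as torch.bool."""

    def __init__(self):
        super().__init__()
        self.register_buffer("kernel_norm_initialized", torch.tensor(False))


class FlagAsFloat(nn.Module):
    """Same flag stored as torch.float32 (0.0 = not initialized, 1.0 = initialized)."""

    def __init__(self):
        super().__init__()
        self.register_buffer("kernel_norm_initialized", torch.tensor(0.0))

    def mark_initialized(self):
        self.kernel_norm_initialized.fill_(1.0)

    def is_initialized(self):
        return bool(self.kernel_norm_initialized.item() > 0.5)


def non_float_entries(module):
    """Map each non-floating-point parameter or buffer name to its dtype."""
    return {
        name: tensor.dtype
        for name, tensor in module.state_dict().items()
        if not tensor.is_floating_point()
    }


bool_report = non_float_entries(FlagAsBool())
float_report = non_float_entries(FlagAsFloat())
print(bool_report, float_report)
```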

2d filters

Hey, nice work!

I was wondering. It seems that in image tasks you convert the features to 1d and then apply the filter. Would it be possible to create 2d filters using the same idea? Did you try that?

Can't get it to run with multi-GPU

Here is my code:

import os
import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

import argparse
from tensorboardX import SummaryWriter

from gconv_standalone import GConv  # GConv lives in gconv_standalone.py (see the traceback below)

gpu_devices = '0,1,2,3'
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices


device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = GConv(
    d_model=256,
    d_state=64,
    l_max=1_000_000,
    bidirectional=True,
    kernel_dim=32,
    n_scales=None,
    decay_min=2,
    decay_max=2,
)

net = nn.DataParallel(net)
net = net.to(device)
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('The number of parameters of model is', num_params)
                
x = torch.randn(1, 256, 1_000_000)
x = x.to(device)

y, k = net(x, return_kernel=True)

And here is the error I am getting:

IndexError: Caught IndexError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/SageMaker/SGConv/gconv_standalone.py", line 416, in forward
    self.kernel_list[i],
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/container.py", line 462, in __getitem__
    idx = self._get_abs_string_index(idx)
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/container.py", line 445, in _get_abs_string_index
    raise IndexError('index {} is out of range'.format(idx))
IndexError: index 0 is out of range
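
The traceback suggests that `self.kernel_list` is empty inside the DataParallel replica. Assuming `kernel_list` is an `nn.ParameterList` (its definition is not shown here), one possible workaround is to register each sub-kernel as a separately named parameter and look it up with `getattr`. The toy module below is a hypothetical sketch of that pattern, not a patch for gconv_standalone.py, and it runs on CPU.

```python
import torch
import torch.nn as nn


class MultiScaleKernels(nn.Module):
    """Toy module (NOT the real GConv) storing sub-kernels as named parameters."""

    def __init__(self, d_model=4, kernel_dim=8, n_scales=3):
        super().__init__()
        self.n_scales = n_scales
        for i in range(n_scales):
            # Each sub-kernel is registered directly on the module
            # instead of being appended to an nn.ParameterList.
            self.register_parameter(
                f"kernel_{i}", nn.Parameter(torch.randn(d_model, kernel_dim))
            )

    def kernels(self):
        return [getattr(self, f"kernel_{i}") for i in range(self.n_scales)]

    def forward(self, x):
        # Concatenate the sub-kernels and apply a placeholder per-channel scaling.
        k = torch.cat(self.kernels(), dim=-1)      # (d_model, n_scales * kernel_dim)
        return x * k.mean(dim=-1, keepdim=True)    # x: (batch, d_model, length)


torch.manual_seed(0)
model = MultiScaleKernels()
x = torch.randn(2, 4, 16)
y = model(x)
param_names = [name for name, _ in model.named_parameters()]
print(param_names)
```

Switching from `nn.DataParallel` to `DistributedDataParallel`, which runs one process per GPU instead of replicating the module on every forward call, is another option worth trying.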

Have you tried even longer sequences? Like billions of tokens?

Hello. I ended up giving a presentation on this paper because I found it so fascinating. I know it performs very well on long-range tasks with sequence lengths around 200k, but have you tried it on much longer sequences, with dependencies billions of tokens apart?
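
As a rough back-of-the-envelope check on what such lengths would cost, the sketch below estimates the memory of a single float32 activation and of its FFT spectrum. It assumes `d_model=256` (the value used in the snippets above), batch size 1, and zero-padding to twice the sequence length before the FFT; the helper name is hypothetical and the real memory use depends on the implementation.

```python
def fft_conv_memory_bytes(d_model, seq_len, batch=1):
    """Bytes for one float32 input tensor and its complex64 rfft spectrum."""
    input_bytes = batch * d_model * seq_len * 4               # float32 = 4 bytes
    n_fft = 2 * seq_len                                       # ASSUMED zero-padding
    spectrum_bytes = batch * d_model * (n_fft // 2 + 1) * 8   # complex64 = 8 bytes
    return input_bytes, spectrum_bytes


GIB = 1024 ** 3
for seq_len in (200_000, 1_000_000, 1_000_000_000):
    input_bytes, spectrum_bytes = fft_conv_memory_bytes(256, seq_len)
    print(f"L={seq_len:>13,}: input {input_bytes / GIB:10.2f} GiB, "
          f"spectrum {spectrum_bytes / GIB:10.2f} GiB")
```

Under these assumptions a single layer's input at one billion tokens already needs about a terabyte, so sequences of that length would require chunking or sharding across devices.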

Reproduce results of LRA benchmark

Hi @ctlllll

How can I reproduce the LRA benchmark results presented in the paper? The GitHub repo only contains code for the SGConv block. Could you please share the full network architecture and training code needed to reproduce the results?

I have already tried the state-space-model code with the S4 block replaced by the SGConv block, but I am having a hard time reproducing the results.

Seeking details of the final SGConv model used for LRA results

Hi,

Thanks a lot for this wonderful work and for sharing it with the community.

I want to reproduce SGConv's results on LRA and have the following questions regarding that. Please help me with them.

  1. Could you please direct me to some references that you used for the final evaluation code (data processing, the evaluation metric computation, etc) so that I can replicate the complete setup?

  2. What are the values of different hyperparameters used for LRA tasks:
    (a) What are the values of the inputs to the __init__ method? Some are in the Appendix of the paper, and you have mentioned some values in this notebook, but I am not sure whether these were the ones used for the LRA tasks. Could you please provide task-specific values of these hyperparameters for LRA?
    (b) How many GConv layers are there in the final model?
    (c) The __init__ method's parameter list defines mode as a string with the single value "cat_randn" (line 277). Is this the value used for all experiments?

Thanks!
