Comments (9)

ltkong218 commented on June 15, 2024

In my experience, the provided correlation package only supports PyTorch 0.4.1. I installed it under Ubuntu 16.04 with gcc 5.4.0. I suggest looking for another implementation of the correlation layer to replace it.

from fastflownet.

AliKafaei commented on June 15, 2024

Thanks for your response. From Ubuntu 18 onward, gcc 5.4 is no longer supported (I have managed to run gcc 5 before, but it was very challenging). If the correlation package supported newer versions of gcc, the work would be much easier to reproduce.

ltkong218 commented on June 15, 2024

Please refer to the Pytorch Correlation module. It supports newer versions of PyTorch, such as 1.2 and later.

pacifinapacific commented on June 15, 2024

Following the link below, I was able to run `pip install spatial-correlation-sampler` successfully. How do I get your program to work with it?
https://github.com/ClementPinard/Pytorch-Correlation-extension

ltkong218 commented on June 15, 2024
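```python
import torch
from spatial_correlation_sampler import SpatialCorrelationSampler

# input1, input2: feature maps of shape (B, C, H, W) from the two frames

# define a correlation module
# positional args: kernel_size=1, patch_size=9, stride=1, padding=0, dilation=1
correlation_sampler = SpatialCorrelationSampler(1, 9, 1, 0, 1)

# output shape: (B, 9, 9, H, W), one entry per displacement in [-4, 4] x [-4, 4]
output = correlation_sampler(input1, input2)

# reshape output to be a 3D cost volume (81 channels) and normalise by C
b, c, h, w = input1.shape
output = output.view(b, -1, h, w) / c
```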
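To make it clearer what this snippet computes, here is a rough pure-PyTorch sketch of the same normalised 81-channel cost volume, built with `F.unfold`. It assumes `SpatialCorrelationSampler(1, 9, 1, 0, 1)` means kernel_size=1, patch_size=9 (displacements -4..4), stride 1, no padding, dilation 1, and that displacement channels are ordered row by row; the function name `cost_volume_unfold` is hypothetical. It is slower and uses more memory than the CUDA extension, but it runs on CPU and can serve as a reference.

```python
import torch
import torch.nn.functional as F


def cost_volume_unfold(input1, input2, max_disp=4):
    """Hypothetical pure-PyTorch cost volume (a sketch, not the project's code).

    Intended to mirror SpatialCorrelationSampler(1, 2*max_disp+1, 1, 0, 1)
    followed by .view(b, -1, h, w) / c; compare against the extension before
    relying on it.
    """
    b, c, h, w = input1.shape
    k = 2 * max_disp + 1
    # Every k x k neighbourhood of input2 (zero-padded): (B, C*k*k, H*W)
    patches = F.unfold(input2, kernel_size=k, padding=max_disp)
    # Split channels into (C, displacement) and restore the spatial grid
    patches = patches.view(b, c, k * k, h, w)
    # Dot product over channels for each displacement, normalised by C
    return (input1.unsqueeze(2) * patches).sum(dim=1) / c


input1 = torch.randn(2, 32, 48, 64)
input2 = torch.randn(2, 32, 48, 64)
print(cost_volume_unfold(input1, input2).shape)  # torch.Size([2, 81, 48, 64])
```

Channel `(dy + 4) * 9 + (dx + 4)` holds the correlation with `input2` shifted by `(dy, dx)`, so channel 40 is the zero-displacement term.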

ChuanchuanZheng commented on June 15, 2024
Hi. I replaced `self.corr` in the original code with this `correlation_sampler` initialization, but I get `RuntimeError: input1 must be contiguous`. Could you tell me how to replace the original corr code so that your code works?

ltkong218 commented on June 15, 2024
```python
import torch
from spatial_correlation_sampler import SpatialCorrelationSampler

# two dummy (B, C, H, W) feature maps on the GPU; freshly created, so contiguous
input1 = torch.randn(2, 32, 48, 64).cuda()
input2 = torch.randn(2, 32, 48, 64).cuda()

# define a correlation module
correlation_sampler = SpatialCorrelationSampler(1, 9, 1, 0, 1)

output = correlation_sampler(input1, input2)

# reshape output to be a 3D cost volume
b, c, h, w = input1.shape
output = output.view(b, -1, h, w) / c

print(output.shape)  # torch.Size([2, 81, 48, 64])
```

I ran the code above and it works. So please check whether your input tensors are contiguous, or call `.contiguous()` on them to make them contiguous.
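One way to swap the sampler into the network, as a minimal sketch: wrap it in a small module that makes both inputs contiguous before calling it, then applies the same reshape and `/ c` normalisation as the snippet above. The class name `ContiguousCorr` is hypothetical, and whether its output matches the original `self.corr` exactly (channel order, scaling) is an assumption worth checking.

```python
import torch
import torch.nn as nn


class ContiguousCorr(nn.Module):
    """Hypothetical drop-in for self.corr built around a correlation sampler.

    `sampler` is meant to be SpatialCorrelationSampler(1, 9, 1, 0, 1), but any
    callable mapping two (B, C, H, W) tensors to (B, PatchH, PatchW, H, W) works.
    """

    def __init__(self, sampler):
        super().__init__()
        self.sampler = sampler

    def forward(self, input1, input2):
        # The extension raises "input1 must be contiguous" for views such as
        # those produced by permute/transpose; .contiguous() copies only if needed.
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        b, c, h, w = input1.shape
        out = self.sampler(input1, input2)
        # Flatten the displacement grid into channels and normalise by C.
        return out.reshape(b, -1, h, w) / c


# Example wiring (requires the spatial_correlation_sampler extension):
#   from spatial_correlation_sampler import SpatialCorrelationSampler
#   self.corr = ContiguousCorr(SpatialCorrelationSampler(1, 9, 1, 0, 1))

# A permuted tensor is a non-contiguous view of the same storage:
x = torch.randn(2, 48, 64, 32).permute(0, 3, 1, 2)
print(x.shape, x.is_contiguous())  # torch.Size([2, 32, 48, 64]) False
```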

fransiskusyoga commented on June 15, 2024

I don't know whether this is correct, but here is what I did:

1. In models/correlation_package/setup.py, remove `, extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}`.
2. In models/correlation_package/correlation_cuda.cc, change `at::globalContext().getCurrentCUDAStream()` to `at::cuda::getCurrentCUDAStream()`, and add `#include <ATen/cuda/CUDAContext.h>`.
3. Edit models/correlation_package/correlation.py as follows:

```python
import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
import correlation_cuda  # extension compiled from models/correlation_package/setup.py


class CorrelationFunction(Function):

    @staticmethod
    def forward(ctx, input1, input2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
        ctx.save_for_backward(input1, input2)
        ctx.corr_params = (pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply)
        # out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)

        with torch.cuda.device_of(input1):
            # empty placeholder tensors; correlation_cuda fills them in place
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
                pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply = ctx.corr_params

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()

            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply)

        # one return value per forward input: gradients for the two tensors,
        # None for the six non-tensor parameters
        return grad_input1, grad_input2, None, None, None, None, None, None


class Correlation(Module):
    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
        super(Correlation, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):
        result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,
                                           self.stride1, self.stride2, self.corr_multiply)
        return result
```

After these edits, I can run the benchmark.
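One caveat if the module is used for training: `torch.autograd.Function.backward` has to return one value per argument passed to `forward`, with `None` for non-tensor arguments; otherwise PyTorch raises an error about an incorrect number of gradients, which only surfaces once backward is actually called. The toy Function below (hypothetical `ScaledProduct`, not part of FastFlowNet) shows the pattern on CPU and verifies it with `torch.autograd.gradcheck`; by the same rule, `CorrelationFunction.backward` needs six trailing `None`s for its six parameters.

```python
import torch
from torch.autograd import Function, gradcheck


class ScaledProduct(Function):
    """Toy Function: two tensor inputs plus one non-tensor parameter."""

    @staticmethod
    def forward(ctx, input1, input2, scale):
        ctx.save_for_backward(input1, input2)
        ctx.scale = scale  # plain Python value, stored on ctx
        return scale * input1 * input2

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        grad_input1 = ctx.scale * grad_output * input2
        grad_input2 = ctx.scale * grad_output * input1
        # One value per forward input; None for the float `scale`.
        return grad_input1, grad_input2, None


a = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
b = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
print(gradcheck(lambda x, y: ScaledProduct.apply(x, y, 2.0), (a, b)))  # True
```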

ltkong218 commented on June 15, 2024

I suggest feeding the same fixed input tensors to the different implementations and comparing their outputs to check it.
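A minimal harness for that check, as a sketch: it builds one pair of seeded inputs, feeds them to two callables and reports the maximum absolute difference. `compare_impls` and its defaults (input shape, tolerance) are assumptions chosen for illustration; in practice `fn_a` and `fn_b` would be, for example, the original `self.corr` and a `SpatialCorrelationSampler`-based replacement, run with `device="cuda"`.

```python
import torch


def compare_impls(fn_a, fn_b, shape=(2, 32, 48, 64), seed=0, device="cpu", atol=1e-5):
    """Hypothetical helper: run two implementations on identical fixed inputs.

    fn_a and fn_b take (input1, input2) and must return tensors of the same
    shape. Returns (max_abs_diff, allclose).
    """
    gen = torch.Generator().manual_seed(seed)  # fixed seed -> fixed inputs
    input1 = torch.randn(shape, generator=gen).to(device)
    input2 = torch.randn(shape, generator=gen).to(device)
    with torch.no_grad():
        out_a = fn_a(input1, input2)
        out_b = fn_b(input1, input2)
    if out_a.shape != out_b.shape:
        raise ValueError(f"shape mismatch: {tuple(out_a.shape)} vs {tuple(out_b.shape)}")
    diff = (out_a - out_b).abs().max().item()
    return diff, torch.allclose(out_a, out_b, atol=atol)


# Two equivalent ways to compute a zero-displacement correlation:
dot_sum = lambda x, y: (x * y).sum(1, keepdim=True)
dot_einsum = lambda x, y: torch.einsum("bchw,bchw->bhw", x, y).unsqueeze(1)
print(compare_impls(dot_sum, dot_einsum))
```

Because both implementations see the same tensors, any mismatch points at the implementations themselves (displacement channel order, normalisation, padding) rather than at random inputs.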
