ivannz commented on June 10, 2024

Thank you for the issue!

Analysis

Here is the minimally reproducing example:

import torch
from torch import nn

from cplxmodule import cplx
import cplxmodule.nn as cplxnn

from torch.nn.parallel.data_parallel import data_parallel


net, x = nn.Conv2d(3, 3, 3), torch.randn(2, 3, 6, 6)
data_parallel(net.cuda(0), x.cuda(0), [0, 1])

net, x = cplxnn.CplxConv2d(3, 3, 3), cplx.Cplx(torch.randn(1, 3, 6, 6))
data_parallel(net.cuda(0), x.cuda(0), [0, 1])

data_parallel uses three key functions:

from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply

scatter_kwargs splits the input along the batch dimension and moves the shards to the appropriate devices. replicate performs in-vivo surgery on the model: it takes the model apart and rebuilds a replica on each device, with parameters that are broadcast copies of the originals. Finally, parallel_apply takes the properly placed inputs and model replicas and runs them in parallel threads, collecting the outputs upon termination.
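To make the division of labour concrete, here is a toy, dependency-free model of the three stages. It is only a sketch: plain lists stand in for tensors, deep copies stand in for per-device replicas, and toy_scatter, toy_replicate, toy_parallel_apply and Doubler are hypothetical names, not PyTorch functions.

```python
import copy
import threading


def toy_scatter(batch, n_devices):
    """Split a list (standing in for a batch) into contiguous chunks."""
    chunk = max(1, -(-len(batch) // n_devices))  # ceiling division, at least 1
    return [batch[i:i + chunk] for i in range(0, len(batch), chunk)]


def toy_replicate(model, n_devices):
    """Make one independent copy of the model per 'device'."""
    return [copy.deepcopy(model) for _ in range(n_devices)]


def toy_parallel_apply(replicas, shards):
    """Run each replica on its shard in its own thread and collect results."""
    results = [None] * len(shards)

    def worker(i, replica, shard):
        results[i] = replica(shard)

    threads = [
        threading.Thread(target=worker, args=(i, r, s))
        for i, (r, s) in enumerate(zip(replicas, shards))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


class Doubler:
    """Stand-in for a model: doubles every element of its input."""

    def __call__(self, xs):
        return [2 * x for x in xs]


def toy_data_parallel(model, batch, n_devices):
    shards = toy_scatter(batch, n_devices)
    # a small batch may yield fewer shards than devices
    replicas = toy_replicate(model, len(shards))
    outputs = toy_parallel_apply(replicas, shards)
    # "gather": concatenate the per-device outputs back into one batch
    return [y for out in outputs for y in out]
```

Note that a batch smaller than the number of devices produces fewer shards than devices in this toy, so only some replicas do any work.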

Replicating the model and moving the inputs manually seems to work fine:

from torch.nn.parallel.replicate import replicate

net = cplxnn.CplxConv2d(3, 3, 3)
x = cplx.Cplx(torch.randn(1, 3, 6, 6), torch.randn(1, 3, 6, 6))

replicas = replicate(net.cuda(0), [0, 1])
print([m.weight.device for m in replicas])
print([m(x.to(m.weight.device)).device for m in replicas])

Since the error message from parallel_apply coincides with

import torch.nn.functional as F

F.conv2d(torch.randn(1, 3, 6, 6).cuda(0), torch.randn(3, 3, 3, 3).cuda(1))

I think the issue is that the input is on the wrong device, while the model is likely on the right one. So I investigated scatter_kwargs.

scatter_kwargs calls scatter on the input tensor. scatter uses internal low-level functionality to split a tensor along the zeroth dimension. Unfortunately, Cplx is a high-level Python object that is duck-typed to behave like a Tensor only at the Python level; it is not binary-compatible with torch.Tensor at the C++ level.
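The consequence can be illustrated with a toy model of type-based dispatch. This is a sketch under stated assumptions, not PyTorch's code: it assumes that scatter splits only the types it recognises (a list stands in for a Tensor here) and hands any other object to every device unchanged. FakeCplx and toy_scatter_map are hypothetical names.

```python
class FakeCplx:
    """Duck-typed container: behaves like a batch, but is not a list."""

    def __init__(self, real, imag):
        self.real, self.imag = real, imag

    def __len__(self):
        return len(self.real)


def toy_scatter_map(obj, n_devices):
    """Return the per-device pieces of `obj`."""
    if isinstance(obj, list):  # the one "known" tensor-like type in this toy
        chunk = max(1, -(-len(obj) // n_devices))  # ceiling division
        return [obj[i:i + chunk] for i in range(0, len(obj), chunk)]
    # assumption: an unrecognised type is neither split nor moved,
    # so every device receives the very same object
    return [obj for _ in range(n_devices)]


plain = toy_scatter_map([1, 2, 3, 4], 2)

opaque = FakeCplx([1, 2, 3, 4], [0, 0, 0, 0])
shared = toy_scatter_map(opaque, 2)
```

In the toy, the list is split into two shards, while both "devices" receive the same unsplit FakeCplx. If the real scatter treats Cplx similarly, the replica on the second GPU gets data that still lives on the first, which would be consistent with the device-mismatch error above.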

Solution

Unfortunately I can only suggest a workaround wrapper.

def r2r_wrap(model, dim_in=1, dim_out=1):
    """Wrap a Cplx model so that it consumes and produces plain Tensors."""
    return torch.nn.Sequential(
        # convert a `B x F*2 x ...` Tensor (dim_in=1) to Cplx as part of the computation
        cplxnn.RealToCplx(dim=dim_in),
        model,
        # convert Cplx back to a `B x O*2 x ...` Tensor (dim_out=1) as part of the pipeline
        cplxnn.CplxToReal(dim=dim_out),
    )

# `cplxtest_net` and `device` are assumed to be defined elsewhere (e.g. in the issue's own code)
model_complex = r2r_wrap(cplxtest_net()).to(device)

This is a thin real-to-real wrapper around the whole model. It makes the conversion from torch Tensors to Cplx and back a part of the model itself, and thus bypasses the scatter issue.

The key nuance is that both the input to and the output of the model are plain Tensors, not cplx.Cplx:

# put real-imag pairs along dim 1 (the channel dimension)
complex_data = torch.randn(1, 3*2, 224, 224).to(device)

# real-imag pairs are assumed to alternate along dim 1
cplx_data = cplxnn.RealToCplx(dim=1)(complex_data)
assert torch.allclose(cplx_data.real, complex_data[:, 0::2])
assert torch.allclose(cplx_data.imag, complex_data[:, 1::2])

The output follows the same layout:

tensor_output = model_complex(complex_data)
cplx_data = cplxnn.RealToCplx(dim=1)(tensor_output)
assert torch.allclose(cplx_data.real, tensor_output[:, 0::2])
assert torch.allclose(cplx_data.imag, tensor_output[:, 1::2])

These Tensors are just a way to store the complex numbers; they in no way affect the arithmetic or operations inside the Cplx network or cplxmodule itself.
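The interleaved storage convention can be spelled out without any dependencies. The sketch below uses Python's built-in complex type and flat lists in place of tensors; pack_interleaved and unpack_interleaved are hypothetical helpers that mirror the [0::2] and [1::2] slicing above, and are not part of cplxmodule.

```python
def pack_interleaved(values):
    """[z0, z1, ...] -> [re0, im0, re1, im1, ...]"""
    flat = []
    for z in values:
        flat.extend((z.real, z.imag))
    return flat


def unpack_interleaved(flat):
    """[re0, im0, re1, im1, ...] -> [z0, z1, ...]"""
    if len(flat) % 2:
        raise ValueError("expected an even number of entries")
    # even positions hold real parts, odd positions hold imaginary parts
    return [complex(re, im) for re, im in zip(flat[0::2], flat[1::2])]


channels = [1 + 2j, 3 - 4j, 0.5j]
stored = pack_interleaved(channels)
restored = unpack_interleaved(stored)
```

Packing and unpacking only rearrange the numbers; the complex values themselves round-trip unchanged, which is the sense in which the storage format does not affect the arithmetic.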
