Thank you for the issue!
Analysis
Here is the minimally reproducing example:
import torch
from torch import nn
from cplxmodule import cplx
import cplxmodule.nn as cplxnn
from torch.nn.parallel.data_parallel import data_parallel
net, x = nn.Conv2d(3, 3, 3), torch.randn(2, 3, 6, 6)
data_parallel(net.cuda(0), x.cuda(0), [0, 1])
net, x = cplxnn.CplxConv2d(3, 3, 3), cplx.Cplx(torch.randn(1, 3, 6, 6))
data_parallel(net.cuda(0), x.cuda(0), [0, 1])
data_parallel uses three key functions:
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply
scatter_kwargs is responsible for splitting the input along the batch dimension and moving the shards to the appropriate devices. replicate is responsible for performing in-vivo surgery on the model: taking it apart and rebuilding it with its parameters wrapped in scatter-gather operations. Finally, parallel_apply takes the properly placed inputs and model replicas and runs them in parallel threads, collecting the outputs upon termination.
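For orientation, the three pieces compose roughly as sketched below. This is a minimal sketch modelled on how data_parallel chains these calls, with gather added at the end to collect the results on one device. The name manual_data_parallel is hypothetical, and the demo at the bottom is skipped unless at least two CUDA devices are visible.

```python
import torch
from torch import nn
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply


def manual_data_parallel(module, inputs, device_ids, dim=0):
    """Hypothetical helper: the data_parallel steps spelled out one by one."""
    if not isinstance(inputs, tuple):
        inputs = (inputs,)
    # 1. split the inputs along `dim` and move each shard to its device
    inputs, kwargs = scatter_kwargs(inputs, {}, device_ids, dim)
    # a small batch may yield fewer shards than devices
    used = device_ids[:len(inputs)]
    # 2. build one replica of the module per device in use
    replicas = replicate(module, used)
    # 3. run each replica on its shard in a separate thread
    outputs = parallel_apply(replicas, inputs, kwargs, used)
    # collect the per-device outputs on the first device
    return gather(outputs, device_ids[0], dim)


if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    net = nn.Conv2d(3, 3, 3).cuda(0)
    x = torch.randn(2, 3, 6, 6).cuda(0)
    print(manual_data_parallel(net, x, [0, 1]).shape)
```

Stepping through these calls one at a time with a Cplx input is a convenient way to see which stage leaves data on the wrong device.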
Replicating the model and moving the inputs manually seems to work fine:
from torch.nn.parallel.replicate import replicate
net = cplxnn.CplxConv2d(3, 3, 3)
x = cplx.Cplx(torch.randn(1, 3, 6, 6), torch.randn(1, 3, 6, 6))
replicas = replicate(net.cuda(0), [0, 1])
print([m.weight.device for m in replicas])
print([m(x.to(m.weight.device)).device for m in replicas])
The error message from parallel_apply coincides with the one produced by:
import torch.nn.functional as F
F.conv2d(torch.randn(1, 3, 6, 6).cuda(0), torch.randn(3, 3, 3, 3).cuda(1))
I think the issue here is that the input is on the wrong device, while the model might be on the right one. Thus I investigated scatter_kwargs.
scatter_kwargs calls scatter on the input tensor. scatter uses some internal low-level functionality to split the tensor along the zeroth dimension. Unfortunately, Cplx is a high-level Python object that is duck-typed to behave like a Tensor only at the Python level. It is not binary compatible with torch.Tensor at the C++ level.
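The effect can be illustrated with a toy model in pure Python. This is only a sketch of the dispatch idea, not torch's actual code: FakeTensor, FakeCplx and toy_scatter are hypothetical stand-ins, and the assumption is that scatter splits real tensors (recursing into tuples) but hands any other object to every device unchanged.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: rows plus a device label."""
    def __init__(self, rows, device="cuda:0"):
        self.rows, self.device = list(rows), device


class FakeCplx:
    """Hypothetical stand-in for cplx.Cplx: holds tensors, is not one."""
    def __init__(self, real, imag):
        self.real, self.imag = real, imag


def toy_scatter(obj, devices):
    if isinstance(obj, FakeTensor):
        # split along dim 0 and "move" each non-empty chunk to its device
        size = -(-len(obj.rows) // len(devices))  # ceil division
        chunks = [obj.rows[i * size:(i + 1) * size] for i in range(len(devices))]
        return [FakeTensor(c, d) for c, d in zip(chunks, devices) if c]
    if isinstance(obj, tuple):
        # scatter each element, then regroup per device
        return list(zip(*(toy_scatter(o, devices) for o in obj)))
    # assumed fallback: unknown objects are passed to every device as-is
    return [obj for _ in devices]


devices = ["cuda:0", "cuda:1"]
shards = toy_scatter(FakeTensor([1, 2, 3, 4]), devices)
print([s.device for s in shards])

z = FakeCplx(FakeTensor([1, 2]), FakeTensor([3, 4]))
copies = toy_scatter(z, devices)
print([c.real.device for c in copies])
```

In this toy model the plain tensor ends up with one shard per device, while the wrapper object reaches the second device still pointing at data on the first, which matches the device-mismatch error above.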
Solution
Unfortunately I can only suggest a workaround wrapper.
def r2r_wrap(model, dim_in=1, dim_out=1):
    return torch.nn.Sequential(
        # convert Tensor `B x F*2 x ...` (dim_in=1) to Cplx as part of the computations
        cplxnn.RealToCplx(dim=dim_in),
        model,
        # convert Cplx back to `B x O*2 x ...` (dim_out=1) Tensor as part of the pipeline
        cplxnn.CplxToReal(dim=dim_out)
    )
model_complex = r2r_wrap(cplxtest_net()).to(device)
This is a thin Real-to-Real wrapper around the whole model, which makes conversion from torch Tensors to Cplx and back a part of the model structure, and thus bypasses the scatter issue.
The key nuance is that both the input to and the output of the model are plain Tensors, not cplx.Cplx:
# put real-imag pairs along the 1st dim
complex_data = torch.randn(1, 3*2, 224, 224).to(device)
# real-imag pairs are assumed to alternate along the 1st dim
cplx_data = cplxnn.RealToCplx(dim=1)(complex_data)
assert torch.allclose(cplx_data.real, complex_data[:, 0::2])
assert torch.allclose(cplx_data.imag, complex_data[:, 1::2])
The output:
tensor_output = model_complex(complex_data)
cplx_data = cplxnn.RealToCplx(dim=1)(tensor_output)
assert torch.allclose(cplx_data.real, tensor_output[:, 0::2])
assert torch.allclose(cplx_data.imag, tensor_output[:, 1::2])
These Tensors are just a way to store the complex numbers and in no way affect their arithmetic or operations inside the Cplx network or cplxmodule itself.
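If the data starts out as separate real and imaginary tensors, it has to be packed into this alternating layout before it is fed to the wrapped model. One possible way to do that with plain torch is sketched below. The helper name interleave_real_imag is hypothetical, and dim is assumed to be a non-negative axis index.

```python
import torch


def interleave_real_imag(real, imag, dim=1):
    """Hypothetical helper: pack real/imag parts as alternating entries along `dim`."""
    # stacking adds a new axis of size 2 right after `dim`: (..., F, 2, ...)
    stacked = torch.stack([real, imag], dim=dim + 1)
    # merging the two axes gives F*2 entries ordered re0, im0, re1, im1, ...
    return stacked.flatten(dim, dim + 1)


real = torch.randn(1, 3, 6, 6)
imag = torch.randn(1, 3, 6, 6)
packed = interleave_real_imag(real, imag)

# same layout that the assertions above check for
assert packed.shape == (1, 6, 6, 6)
assert torch.equal(packed[:, 0::2], real)
assert torch.equal(packed[:, 1::2], imag)
```

The result is an ordinary Tensor, so it is split across devices like any other input.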