tvayer / sgw
Code for Sliced Gromov-Wasserstein
Dear Titouan,
Thanks for making this! I'm adding FGW to build a sliced FGW.
For whatever reason, this currently fails:
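import numpy as np
from sgw_numpy import sgw_cpu  # defined in lib/sgw_numpy.py; assumes lib/ is on sys.path

s_samples, t_samples = 100, 200  # source and target clouds have different sizes
n_projections = 5
dimensionality = 8
Xs = np.random.rand(s_samples, dimensionality)
Xt = np.random.rand(t_samples, dimensionality)
P = np.random.randn(dimensionality, n_projections)  # one projection direction per column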
%%time
sgw_cpu(Xs,Xt,P=P)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<timed eval> in <module>
~/hax/SGW/lib/sgw_numpy.py in sgw_cpu(xs, xt, nproj, tolog, P)
66 log['gw_1d_details']=log_gw1d
67 else:
---> 68 d=gromov_1d(xsp,xtp,tolog=False)
69
70 if tolog:
~/hax/SGW/lib/sgw_numpy.py in gromov_1d(xs, xt, tolog, fast)
170 xt_asc=np.sort(xt,axis=0)
171 xt_desc=np.sort(xt,axis=0)[::-1]
--> 172 l1=_cost(xs2,xt_asc,tolog=tolog)
173 l2=_cost(xs2,xt_desc,tolog=tolog)
174 toreturn=np.mean(np.minimum(l1,l2))
~/hax/SGW/lib/sgw_numpy.py in _cost(xsp, xtp, tolog)
109 Y4=np.sum(xt**4)
110
--> 111 xxyy_=np.sum((xs**2)*(xt**2))
112 xxy_=np.sum((xs**2)*(xt))
113 xyy_=np.sum((xs)*(xt**2))
ValueError: operands could not be broadcast together with shapes (100,) (200,)
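The traceback points at the cause: `gromov_1d` sorts the projected target in ascending and descending order, and `_cost` then multiplies projected source and target samples elementwise (`(xs**2)*(xt**2)`), which NumPy can only do when both have the same length, here 100 vs 200. Below is a minimal NumPy sketch of that sort-and-compare structure as read from the traceback, not the repository's code. The helper names `gw_1d_sorted` and `sliced_gw` are hypothetical, the 1D cost is computed directly from pairwise squared distances (O(n²) per projection), and its normalization may differ from the repository's. Like `gromov_1d`, it only compares the ascending and descending matchings and does not search over other matchings.

```python
import numpy as np

def gw_1d_sorted(xs_proj, yt_proj):
    """Hypothetical helper: 1D GW-style cost for one projection.

    xs_proj, yt_proj: 1D arrays of equal length n. The target is matched to the
    sorted source either in ascending or in descending order; the smaller of
    the two costs is returned.
    """
    if xs_proj.shape != yt_proj.shape:
        raise ValueError(f"equal sample counts required, got {xs_proj.shape} and {yt_proj.shape}")
    x = np.sort(xs_proj)
    y_asc = np.sort(yt_proj)
    y_desc = y_asc[::-1]
    # Pairwise squared distances within the source projection, shape (n, n).
    dx = (x[:, None] - x[None, :]) ** 2

    def cost(y):
        dy = (y[:, None] - y[None, :]) ** 2
        return np.mean((dx - dy) ** 2)  # average over all n * n pairs

    return min(cost(y_asc), cost(y_desc))

def sliced_gw(Xs, Xt, P):
    """Hypothetical helper: average the 1D cost over the columns of P (d x L)."""
    xsp = Xs @ P  # (n, L) projected source
    xtp = Xt @ P  # (n, L) projected target
    return np.mean([gw_1d_sorted(xsp[:, k], xtp[:, k]) for k in range(P.shape[1])])

rng = np.random.default_rng(0)
Xs = rng.random((100, 8))
Xt = rng.random((200, 8))
P = rng.standard_normal((8, 5))
# sliced_gw(Xs, Xt, P) raises ValueError here because 100 != 200.
# Hypothetical workaround (an assumption, it changes what is estimated):
# subsample the larger cloud to the size of the smaller one.
idx = rng.choice(Xt.shape[0], size=Xs.shape[0], replace=False)
print(sliced_gw(Xs, Xt[idx], P))
```

Subsampling is only one option; repeating it with several random subsets and averaging would reduce the variance it introduces.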
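import torch
from sgw_pytorch import sgw_gpu  # defined in lib/sgw_pytorch.py; assumes lib/ is on sys.path
device = 'cuda'
xs = torch.from_numpy(Xs).to(torch.float32).to(device)
xt = torch.from_numpy(Xt).to(torch.float32).to(device)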
%%time
sgw_gpu(xs,xt,device,P=torch.from_numpy(P).to(torch.float32))
----------------------------------------
xs origi dim : torch.Size([100, 8])
xt origi dim : torch.Size([200, 8])
dim_p : 8
dim_d : 8
random_projection_dim : 8
projector dimension : torch.Size([2, 5])
xs2 dim : torch.Size([100, 8])
xt2 dim : torch.Size([200, 8])
xs_tmp dim : torch.Size([100, 8])
xt_tmp dim : torch.Size([200, 8])
----------------------------------------
size mismatch, m1: [100 x 8], m2: [2 x 5] at /opt/conda/conda-bld/pytorch_1587428266983/work/aten/src/THC/generic/THCTensorMathBlas.cu:283
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
~/hax/SGW/lib/sgw_pytorch.py in sink_(xs, xt, device, nproj, P)
242 try:
--> 243
244 xsp = torch.matmul(xs2, p.to(device))
RuntimeError: size mismatch, m1: [100 x 8], m2: [2 x 5] at /opt/conda/conda-bld/pytorch_1587428266983/work/aten/src/THC/generic/THCTensorMathBlas.cu:283
During handling of the above exception, another exception occurred:
BadShapeError Traceback (most recent call last)
<timed eval> in <module>
~/hax/SGW/lib/sgw_pytorch.py in sgw_gpu(xs, xt, device, nproj, tolog, P)
61 xsp, xtp = sink_(xs, xt, device, nproj, P)
62 ed = time.time()
---> 63 log['time_sink_'] = ed-st
64 else:
65 xsp, xtp = sink_(xs, xt, device, nproj, P)
~/hax/SGW/lib/sgw_pytorch.py in sink_(xs, xt, device, nproj, P)
257 print('xt_tmp dim :', xt2.shape)
258 print('----------------------------------------')
--> 259 print(error)
260 raise BadShapeError
261
BadShapeError:
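The log shows `torch.matmul` receiving a [100 x 8] input and a [2 x 5] projector, even though `P` was created with shape (8, 5): by the time it reaches `sink_`, the projector no longer has one row per input feature. A shape guard before projecting would report this with a clearer message. The sketch below uses NumPy so it runs without a GPU; the name `check_projector` is hypothetical and not part of the repository, and the same condition (`P.shape[0] == X.shape[1]`) applies to `torch.matmul` on 2D tensors.

```python
import numpy as np

def check_projector(X, P):
    """Hypothetical shape guard: X is (n_samples, d), P must be (d, n_projections)."""
    if X.ndim != 2 or P.ndim != 2:
        raise ValueError(f"expected 2D arrays, got X.ndim={X.ndim}, P.ndim={P.ndim}")
    if P.shape[0] != X.shape[1]:
        raise ValueError(
            f"projector has {P.shape[0]} rows but the data has {X.shape[1]} features; "
            "X @ P needs P.shape[0] == X.shape[1]"
        )
    return X @ P

X = np.zeros((100, 8))
print(check_projector(X, np.ones((8, 5))).shape)  # (100, 5)
# check_projector(X, np.ones((2, 5))) raises ValueError, mirroring the failure above.
```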
Assume I have two tensors of shape (batch_size, number_of_points, channels).
How can I compute SGW in batch?
For example, 5 point clouds against another 5 point clouds.
Is there a simple way to compute the results as a batch?
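One way to batch the computation is to keep the batch axis throughout: project every cloud with a shared matrix, sort along the points axis, and reduce over pairs and projections. The NumPy sketch below assumes paired inputs (cloud b of `Xs` is compared only with cloud b of `Xt`), equal point counts, and the same ascending/descending matching visible in the repository's `gromov_1d`; `batched_sgw` is a hypothetical name, not the repository's API. It builds (B, n, n, L) distance arrays, so memory grows quadratically with the number of points; the repository's `_cost` uses sums of powers of the projections instead, which avoids those matrices.

```python
import numpy as np

def batched_sgw(Xs, Xt, P):
    """Paired sliced GW-style cost for a batch of point clouds (a sketch).

    Xs, Xt: arrays of shape (B, n, d); cloud b of Xs is compared with cloud b of Xt.
    P: projection matrix of shape (d, L), shared across the batch.
    Returns an array of shape (B,).
    """
    if Xs.shape != Xt.shape:
        raise ValueError(f"expected equal shapes (B, n, d), got {Xs.shape} and {Xt.shape}")
    xs = np.sort(Xs @ P, axis=1)      # (B, n, L): projections sorted per cloud and direction
    yt_asc = np.sort(Xt @ P, axis=1)  # (B, n, L)
    yt_desc = yt_asc[:, ::-1, :]      # same values in descending order
    # Pairwise squared distances along each projection, shape (B, n, n, L).
    dx = (xs[:, :, None, :] - xs[:, None, :, :]) ** 2

    def cost(y):
        dy = (y[:, :, None, :] - y[:, None, :, :]) ** 2
        return np.mean((dx - dy) ** 2, axis=(1, 2))  # (B, L)

    # Keep the cheaper matching per projection, then average over projections.
    return np.minimum(cost(yt_asc), cost(yt_desc)).mean(axis=1)

rng = np.random.default_rng(0)
Xs = rng.random((5, 64, 3))        # 5 clouds of 64 points in 3D
Xt = rng.random((5, 64, 3))
P = rng.standard_normal((3, 10))
P /= np.linalg.norm(P, axis=0, keepdims=True)  # unit-norm directions
print(batched_sgw(Xs, Xt, P).shape)  # (5,)
```

Swapping the NumPy calls for their PyTorch counterparts (for example `torch.sort(..., dim=1).values`) should give a GPU version, though that port is not shown here.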
Would it be possible to add an MIT or Apache 2.0 license?
Thanks for your work. I'm curious how this could be applied to unbalanced organic chemistry reactions and protein mechanics. Would it be permutation invariant for atoms of the same element? (We use two unbalanced arrays of shape (N, 16), where the 16 features include xyz coordinates, mass, charge, element, etc.)
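On permutation invariance: in a sliced approach each cloud is reduced to sorted 1D projections, so reordering the rows of an input leaves the sorted values, and hence any cost computed from them, unchanged. This holds for any reordering of points, not only among atoms of the same element. The check below illustrates it numerically with a hypothetical helper `sorted_projections`; it does not address the unbalanced case, and the first issue in this thread (100 vs 200 samples) suggests the repository's functions expect equal point counts.

```python
import numpy as np

def sorted_projections(X, P):
    """Hypothetical helper: project rows of X (n, d) on the columns of P (d, L), sort each column."""
    return np.sort(X @ P, axis=0)

rng = np.random.default_rng(0)
X = rng.random((50, 16))            # e.g. 50 atoms with 16 features each
P = rng.standard_normal((16, 20))
perm = rng.permutation(X.shape[0])  # arbitrary reordering of the atoms
same = np.allclose(sorted_projections(X, P), sorted_projections(X[perm], P))
print(same)  # True
```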
Is it possible to use SGW in a GAN with a DCGAN architecture, training it by minimizing the SGW loss?