Hi Yang,
I'm getting the following error:
  File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/losses.py", line 111, in loss_fn
    losses = jnp.square(batch_mul(score, std) + z)
  File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/utils.py", line 42, in batch_mul
    return jax.vmap(lambda a, b: a * b)(a, b)
TypeError: can't multiply sequence by non-int of type 'BatchTracer'
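A minimal sketch that reproduces this error outside the training loop. It assumes that score reaches batch_mul as a tuple wrapping the array rather than as the array itself; the leading "(" in the score printout is consistent with this. The batch_mul body is copied from the traceback, but its signature is assumed, and the arrays are dummy values with the printed shapes. The exact error message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def batch_mul(a, b):
    # Body copied from score_sde/utils.py line 42 in the traceback; the
    # signature is an assumption. Multiplies each example of `a` by the
    # matching scalar in `b`.
    return jax.vmap(lambda a, b: a * b)(a, b)


std = jnp.ones((8,))             # same shape as `std` in the printout
z = jnp.zeros((8, 32, 32, 3))    # same shape as `z` in the printout
score_array = jnp.zeros((8, 32, 32, 3))

# Works: each vmapped slice is a (32, 32, 3) array times a scalar.
ok = batch_mul(score_array, std) + z

# Assumed failure mode: score wrapped in a 1-tuple. vmap batches the tuple's
# leaf, but inside the lambda `a` is still a tuple, so `a * b` is a TypeError.
score_tuple = (score_array,)
try:
    batch_mul(score_tuple, std)
    tuple_error = None
except TypeError as err:
    tuple_error = err
```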
If I use chex.fake_pmap so that I can print inside the pmap, I see:
std:
Traced<ShapedArray(float32[8])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[ 0.21264698, 0.77755433, 0.27918625, 0.9448618 ,
10.666621 , 0.24025024, 12.233008 , 3.626547 ]], dtype=float32)
batch_dim = 0
z:
Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[[[[-2.01293421e+00, -2.17641640e+00, -1.23569024e+00],
[ 6.13737464e-01, 1.50414258e-01, -2.59380966e-01],
....
score:
(Traced<ShapedArray(float32[8,32,32,3])>with<JVPTrace(level=3/0)>
with primal = Traced<ShapedArray(float32[8,32,32,3])>with<BatchTrace(level=1/0)>
with val = DeviceArray([[[[[ 8.45328856e-08, -1.25030866e-07, -8.07002252e-08],
...
I tried to match your library versions as closely as possible: the same jax, flax and jaxlib versions, and tensorflow_gpu==2.4.1. I use one GPU and config="configs/ncsnpp/cifar10_continuous_ve.py".
Update: If I run it with cifar10_continuous_vp.py instead, it breaks here:
  File "/localscratch/jolicoea.62752629.0/1/ScoreSDEMore/score_sde/models/utils.py", line 197, in score_fn
    score = batch_mul(-model, 1. / std)
TypeError: bad operand type for unary -: 'tuple'
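A minimal sketch of this second error, assuming (hypothetically) that the model output is a 1-tuple wrapping a [8,32,32,3] array. It only confirms that negating a tuple, rather than an array, produces exactly this message. It is not a proposed fix, and the arrays are dummy values.

```python
import jax.numpy as jnp

model_array = jnp.ones((8, 32, 32, 3))
model_tuple = (model_array,)  # assumed: output wrapped in a 1-tuple

# Negating an array is fine.
neg = -model_array

# Negating a tuple fails the same way as score_fn line 197.
try:
    -model_tuple
    unary_error = None
except TypeError as err:
    unary_error = str(err)
```

If both config files lead to a tuple-valued model output, the VE and VP failures would share one cause and differ only in which operation touches the tuple first.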