Hi there,
Thanks for this amazing tutorial. I was trying to run through it, but I hit the error below. Full traceback:
```
ValueError Traceback (most recent call last)
<ipython-input-...> in <module>()
1 key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
----> 2 trained_params, opt_details,_ = optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps,train_data, eval_data)
3 # trained_params, opt_details, _ =
4 # optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps,
5 # train_data_fun, eval_data_fun)
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps, train_data, eval_data)
159 print_every, update_w_gc, kl_warmup_fun,
160 opt_state, lfads_hps, lfads_opt_hps,
--> 161 train_data)
162 batch_time = time.time() - start_time
163
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in f_jitted(*args, **kwargs)
148 _check_args(args_flat)
149 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 150 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
151 return tree_unflatten(out_tree(), out)
152
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/core.pyc in call_bind(primitive, f, *args, **params)
596 if top_trace is None:
597 with new_sublevel():
--> 598 outs = primitive.impl(f, *args, **params)
599 else:
600 tracers = map(top_trace.full_raise, args)
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/xla.pyc in _xla_call_impl(fun, *args, **params)
440 device = params['device']
441 backend = params['backend']
--> 442 compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
443 try:
444 return compiled_fun(*args)
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in memoized_fun(fun, *args)
208 fun.populate_stores(stores)
209 else:
--> 210 ans = call(fun, *args)
211 cache[key] = (ans, fun.stores)
212 return ans
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/xla.pyc in _xla_callable(fun, device, backend, *arg_specs)
457 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
458 with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459 jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
460 assert not env # no subtraces here
461 del master, env
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in optimize_lfads_core(key, batch_idx_start, num_batches, update_fun, kl_warmup_fun, opt_state, lfads_hps, lfads_opt_hps, train_data)
97 lower = batch_idx_start
98 upper = batch_idx_start + num_batches
---> 99 return lax.fori_loop(lower, upper, run_update, opt_state)
100
101
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in fori_loop(lower, upper, body_fun, init_val)
141 raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))
142 _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
--> 143 (lower, upper, init_val))
144 return result
145
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in while_loop(cond_fun, body_fun, init_val)
192 init_avals = tuple(_map(_abstractify, init_vals))
193 cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
--> 194 body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
195 if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
196 msg = "cond_fun must return a boolean scalar, but got pytree {}."
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in _initial_style_jaxpr(fun, in_tree, in_avals)
59 fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
60 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
---> 61 stage_out_calls=True)
62 out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
63 const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/partial_eval.pyc in trace_to_jaxpr(fun, pvals, instantiate, stage_out_calls)
331 with new_master(trace_type) as master:
332 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 333 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
334 assert not env
335 del master
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in while_body_fun(loop_carry)
93 def while_body_fun(loop_carry):
94 i, upper, x = loop_carry
---> 95 return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
96 return while_body_fun
97
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in run_update(batch_idx, opt_state)
92 x_bxt = train_data[didxs].astype(np.float32)
93 opt_state = update_fun(batch_idx, opt_state, lfads_hps, lfads_opt_hps,
---> 94 next(fkeyg), x_bxt, kl_warmup)
95 return opt_state
96
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in update_w_gc(i, opt_state, lfads_hps, lfads_opt_hps, key, x_bxt, kl_warmup)
140 grads = grad(lfads.lfads_training_loss)(params, lfads_hps, key, x_bxt,
141 kl_warmup,
--> 142 lfads_opt_hps['keep_rate'])
143 clipped_grads = optimizers.clip_grads(grads, lfads_opt_hps['max_grad_norm'])
144 return opt_update(i, clipped_grads, opt_state)
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in grad_f(*args, **kwargs)
353 @wraps(fun, docstr=docstr, argnums=argnums)
354 def grad_f(*args, **kwargs):
--> 355 _, g = value_and_grad_f(*args, **kwargs)
356 return g
357
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in value_and_grad_f(*args, **kwargs)
408 f_partial, dyn_args = _argnums_partial(f, argnums, args)
409 if not has_aux:
--> 410 ans, vjp_py = vjp(f_partial, *dyn_args)
411 else:
412 ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in vjp(fun, *primals, **kwargs)
1270 if not has_aux:
1271 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1272 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1273 out_tree = out_tree()
1274 else:
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/ad.pyc in vjp(traceable, primals, has_aux)
106 def vjp(traceable, primals, has_aux=False):
107 if not has_aux:
--> 108 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
109 else:
110 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/ad.pyc in linearize(traceable, *primals, **kwargs)
95 _, in_tree = tree_flatten(((primals, primals), {}))
96 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 97 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
98 pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
99 aval_primals, const_primals = unzip2(pval_primals)
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/partial_eval.pyc in trace_to_jaxpr(fun, pvals, instantiate, stage_out_calls)
331 with new_master(trace_type) as master:
332 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 333 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
334 assert not env
335 del master
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads_training_loss(params, lfads_hps, key, x_bxt, kl_scale, keep_rate)
540 return the total loss for optimization
541 """
--> 542 losses = lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate)
543 return losses['total']
544
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate)
488 key, skeys = utils.keygen(key, 2)
489 keys_b = random.split(next(skeys), B)
--> 490 lfads = batch_lfads(params, lfads_hps, keys_b, x_bxt, keep_rate)
491
492 # Sum over time and state dims, average over batch.
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in batched_fun(*args)
693 _check_axis_sizes(in_tree, args_flat, in_axes_flat)
694 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 695 lambda: _flatten_axes(out_tree(), out_axes))
696 return tree_unflatten(out_tree(), out_flat)
697
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/batching.pyc in batch(fun, in_vals, in_dims, out_dim_dests)
41 def batch(fun, in_vals, in_dims, out_dim_dests):
42 size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
---> 43 out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
44 return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
45
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/batching.pyc in batch_fun(fun, in_vals, in_dims)
47 with new_master(BatchTrace) as master:
48 fun, out_dims = batch_subtrace(fun, master, in_dims)
---> 49 out_vals = fun.call_wrapped(*in_vals)
50 del master
51 return out_vals, out_dims()
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads(params, lfads_hps, key, x_t, keep_rate)
447
448 ic_mean, ic_logvar, xenc_t = \
--> 449 lfads_encode(params, lfads_hps, next(skeys), x_t, keep_rate)
450
451 c_t, gen_t, factor_t, ii_t, ii_mean_t, ii_logvar_t, lograte_t = \
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads_encode(params, lfads_hps, key, x_t, keep_rate)
320 x_t = run_dropout(x_t, next(skeys), keep_rate)
321 con_ins_t, gen_pre_ics = run_bidirectional_rnn(params['ic_enc'], gru, gru,
--> 322 x_t)
323 # Push through to posterior mean and variance for initial conditions.
324 xenc_t = dropout(con_ins_t, next(skeys), keep_rate)
/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in run_bidirectional_rnn(params, fwd_rnn, bwd_rnn, x_t)
254 params['bwd_rnn']['h0']))
255 full_enc = np.concatenate([fwd_enc_t, bwd_enc_t], axis=1)
--> 256 enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=1)
257 return full_enc, enc_ends
258
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc in concatenate(arrays, axis)
1637 if ndim(arrays[0]) == 0:
1638 raise ValueError("Zero-dimensional arrays cannot be concatenated.")
-> 1639 axis = _canonicalize_axis(axis, ndim(arrays[0]))
1640 arrays = _promote_dtypes(*arrays)
1641 # lax.concatenate can be slow to compile for wide concatenations, so form a
/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc in _canonicalize_axis(axis, num_dims)
378 raise ValueError(
379 "axis {} is out of bounds for array of dimension {}".format(
--> 380 axis, num_dims))
381 return axis
382
ValueError: axis 1 is out of bounds for array of dimension 1
```
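In case it helps narrow things down, I can reproduce the bottom-most failure (the `np.concatenate` call in `run_bidirectional_rnn`, lfads.py line 256) in isolation. This is only my guess at what is going on, assuming `fwd_enc_t` and `bwd_enc_t` are (T, H) arrays of hidden states, so that `bwd_enc_t[0]` and `fwd_enc_t[-1]` are 1-D vectors of length H:

```python
# Standalone sketch of the failing call, under the shape assumption above.
import jax.numpy as np  # the tutorial imports jax.numpy as np

bwd_first = np.zeros(3)  # stand-in for bwd_enc_t[0], shape (H,)
fwd_last = np.ones(3)    # stand-in for fwd_enc_t[-1], shape (H,)

try:
    np.concatenate([bwd_first, fwd_last], axis=1)
except ValueError as e:
    print(e)  # "axis 1 is out of bounds for array of dimension 1"

# Concatenating the same 1-D inputs along axis=0 works fine:
print(np.concatenate([bwd_first, fwd_last], axis=0).shape)  # (6,)
```

So on my install, `np.concatenate` rejects `axis=1` for 1-D inputs. I don't know whether the tutorial expects an older jax where this was accepted, or whether my encoder outputs have the wrong shape. Should that line use `axis=0`, or am I on an incompatible jax version?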
Could you please help me figure out the right fix?
Thanks!