computation-thru-dynamics's Introduction

Computation Through Dynamics

This repository contains a number of subprojects related to the interlinking of computation and dynamics in artificial and biological neural systems.

This is not an officially supported Google product.

Prerequisites

The code is written to be compatible with Python 3. You will also need:

  • JAX version 0.1.75 or greater (install)
  • jaxlib, latest version (installed with JAX)
  • NumPy, SciPy, Matplotlib (install the SciPy stack, which contains all of them)
  • h5py (install)
  • A GPU - XLA compiles these examples very slowly on CPU, so it is best to use a GPU for now.

Analysis of the toy model associated with How recurrent networks implement contextual processing in sentiment analysis

Neural networks have a remarkable capacity for contextual processing—using recent or nearby inputs to modify processing of current input. For example, in natural language, contextual processing is necessary to correctly interpret negation (e.g. phrases such as “not bad”). However, our ability to understand how networks process context is limited. Here, we propose general methods for reverse engineering recurrent neural networks (RNNs) to identify and elucidate contextual processing.

This Jupyter notebook runs through the analysis of the toy model found in How recurrent networks implement contextual processing in sentiment analysis.

LFADS - Latent Factor Analysis via Dynamical Systems

LFADS is a tool for inferring dynamics from noisy, high-dimensional observations of a dynamical system. It is a sequential auto-encoder with some very particular bells and whistles. Here we have released a tutorial version, written in Python / NumPy / JAX, intentionally implemented with readability, comprehension and innovation in mind. You may find the full TensorFlow implementation with run manager support (here).

The LFADS tutorial uses the integrator RNN example (see below). The LFADS tutorial example attempts to infer the hidden states of the integrator RNN as well as the white noise input to the RNN. Run the integrator RNN example, then copy the resulting data file, written to /tmp/, into /tmp/LFADS/data/. Edit the name of the data file in run_lfads.py and then execute run_lfads.py.
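A rough sketch of that hand-off in Python follows; it is not part of the repository, and the source file name is a hypothetical placeholder for whatever the integrator RNN example actually wrote:

import os
import shutil

src = '/tmp/integrator_rnn_data.h5'   # hypothetical name; substitute your actual data file
dst_dir = '/tmp/LFADS/data/'

os.makedirs(dst_dir, exist_ok=True)   # make sure the LFADS data directory exists
shutil.copy(src, dst_dir)             # copy the integrator RNN data file over

# Then point run_lfads.py at the copied file name and execute it.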

The LFADS tutorial is run through this Jupyter notebook.

Integrator RNN - train a Vanilla RNN to integrate white noise.

Integration is a very simple task and highlights how to set up a loop over time, batch over multiple input/target examples, use just-in-time compilation to speed the computation up, and take a gradient in JAX. The data from this example is also used as input for the LFADS tutorial.
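The sketch below is not the tutorial's code, only an illustration of those four ingredients under simplified assumptions: a vanilla tanh RNN integrates white noise using a lax.scan time loop, vmap batching over examples, jit compilation, and jax.grad for the gradient. All sizes are arbitrary.

import jax
import jax.numpy as jnp
from jax import lax, random

T, N, B = 50, 32, 16   # time steps, hidden units, batch size (illustrative)

def init_params(key):
  kI, kR, kO = random.split(key, 3)
  return {'wI': 0.1 * random.normal(kI, (N, 1)),
          'wR': random.normal(kR, (N, N)) / jnp.sqrt(N),
          'wO': 0.1 * random.normal(kO, (1, N)),
          'h0': jnp.zeros((N,))}

def rnn_outputs(params, inputs_t):
  # Loop over time with lax.scan: carry the hidden state, emit the readout.
  def step(h, x):
    h = jnp.tanh(params['wR'] @ h + params['wI'] @ x)
    return h, params['wO'] @ h
  _, out_t = lax.scan(step, params['h0'], inputs_t)
  return out_t                                     # shape (T, 1)

def loss(params, inputs_bxt, targets_bxt):
  # Batch the single-example RNN over the leading batch dimension with vmap.
  outs = jax.vmap(rnn_outputs, in_axes=(None, 0))(params, inputs_bxt)
  return jnp.mean((outs - targets_bxt) ** 2)

key = random.PRNGKey(0)
params = init_params(key)
inputs = 0.1 * random.normal(key, (B, T, 1))       # white-noise inputs
targets = jnp.cumsum(inputs, axis=1)               # running integral as the target

# jit compiles the whole computation; grad differentiates the loss w.r.t. params.
grads = jax.jit(jax.grad(loss))(params, inputs, targets)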

This example is run through this Jupyter notebook.

Fixed point finding - train a GRU to make a binary decision and study it via fixed point finding.

The goal of this tutorial is to learn about fixed point finding by running the algorithm on a Gated Recurrent Unit (GRU) trained to make a binary decision, namely whether the integral of the white noise input is positive or negative overall, outputting either a +1 or a -1 to encode the decision.

Running the fixed point finder on this decision-making GRU will yield:

  1. the underlying fixed points
  2. the first-order Taylor series approximations around those fixed points.

Going through this tutorial will exercise the concepts defined in Opening the black box: low-dimensional dynamics in high-dimensional recurrent neural networks.
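A minimal sketch of both steps is shown below, using a hypothetical, randomly initialized GRU update in place of the tutorial's trained network: candidate fixed points are found by minimizing the speed q(h) = 1/2 ||F(h) - h||^2, and the first-order Taylor expansion comes from the Jacobian of the update at the candidate point.

import jax
import jax.numpy as jnp
from jax import random

N = 64                                   # hidden size (arbitrary for this sketch)
key = random.PRNGKey(0)
kz, kr, kh = random.split(key, 3)
Wz = random.normal(kz, (N, N)) / jnp.sqrt(N)
Wr = random.normal(kr, (N, N)) / jnp.sqrt(N)
Wh = random.normal(kh, (N, N)) / jnp.sqrt(N)

def gru_update(h):
  # One GRU step at a fixed (here zero) input; stands in for the trained RNN.
  z = jax.nn.sigmoid(Wz @ h)
  r = jax.nn.sigmoid(Wr @ h)
  h_tilde = jnp.tanh(Wh @ (r * h))
  return (1.0 - z) * h + z * h_tilde

def speed(h):
  dh = gru_update(h) - h
  return 0.5 * jnp.sum(dh ** 2)          # q(h); exactly zero at a fixed point

# 1. Find a candidate fixed point by gradient descent on q(h).
h = 0.1 * random.normal(key, (N,))
speed_grad = jax.jit(jax.grad(speed))
for _ in range(2000):
  h = h - 0.1 * speed_grad(h)

# 2. First-order Taylor expansion around the candidate fixed point h*:
#    F(h* + dh) ~ F(h*) + J dh, where J is the Jacobian of the update at h*.
J = jax.jacobian(gru_update)(h)
eigvals = jnp.linalg.eigvals(J)          # eigenvalues characterize the local dynamics
print(speed(h), jnp.max(jnp.abs(eigvals)))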

This example is run through this Jupyter notebook.

FORCE learning in Echo state networks

This Colab / IPython notebook trains an echo state network (ESN) to generate the chaotic output of another recurrent neural network. It implements a continuous-time ESN with FORCE learning via recursive least squares (RLS). It also lets you use a GPU and quickly get started with JAX! Two different implementations are explored, one at the JAX / Python level and another at the LAX level. After JIT compilation, the JAX implementation runs very fast.
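A minimal sketch of the FORCE/RLS update at the Python level follows. It is not the notebook's code: a simple periodic target stands in for the chaotic RNN output the notebook actually uses, and all sizes and constants are illustrative.

import jax.numpy as jnp
from jax import random

N, dt, tau, alpha = 200, 0.1, 1.0, 1.0
key = random.PRNGKey(0)
kJ, kfb, kx = random.split(key, 3)
J = 1.5 * random.normal(kJ, (N, N)) / jnp.sqrt(N)   # recurrent weights (gain 1.5)
w_fb = random.normal(kfb, (N,))                     # output feedback weights
w_out = jnp.zeros((N,))                             # readout weights to be trained
P = jnp.eye(N) / alpha                              # RLS inverse correlation matrix

x = 0.5 * random.normal(kx, (N,))                   # network state
T = 2000
ts = jnp.arange(T) * dt
f_target = jnp.sin(2 * jnp.pi * ts / 10.0)          # placeholder target signal

for t in range(T):
  r = jnp.tanh(x)
  z = w_out @ r                                     # readout before the weight update
  # Euler step of the continuous-time ESN with output feedback.
  x = x + dt / tau * (-x + J @ r + w_fb * z)
  # RLS (FORCE) update of the readout weights.
  Pr = P @ r
  k = Pr / (1.0 + r @ Pr)
  P = P - jnp.outer(k, Pr)
  e = z - f_target[t]
  w_out = w_out - e * k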

computation-thru-dynamics's People

Contributors

mattjj, maxpatiiuk, youngju-jo

computation-thru-dynamics's Issues

Couldn't reproduce TSNE result of LFADS inferred initial generator state

I used the integrator RNN tutorial to produce the input data for the LFADS tutorial.
All parameters in the integrator RNN and LFADS tutorials are unchanged.
But at the last step, the t-SNE visualization of the LFADS-inferred initial generator state, I couldn't get a shape like the one in the LFADS tutorial code (a circular, concentrated line). Instead, no matter how many times I re-ran the tutorial, my result shows a spread-out, unstructured shape, as in the image below:

[Screenshot from 2021-11-05 15-14-32]

Does anyone also have this problem?

My package versions:
ubuntu 18.04
jax 0.2.20
jaxlib 0.1.71+cuda111
scikit-learn 0.24.2

Undefined names

flake8 testing of https://github.com/google-research/computation-thru-dynamics on Python 3.7.1

$ flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics

./integrator_rnn_tutorial/run_integrator_rnn.py:76:3: F821 undefined name 'plot_batch'
  plot_batch(inputs, targets)
  ^
./integrator_rnn_tutorial/integrator.py:60:16: F821 undefined name 'ntimesteps'
  plt.xlim([0, ntimesteps-1])
               ^
./integrator_rnn_tutorial/integrator.py:65:18: F821 undefined name 'ntimesteps'
    plt.xlim([0, ntimesteps-1]);
                 ^
./integrator_rnn_tutorial/integrator.py:68:18: F821 undefined name 'ntimesteps'
    plt.xlim([0, ntimesteps-1]);
                 ^
./integrator_rnn_tutorial/integrator.py:73:18: F821 undefined name 'ntimesteps'
    plt.xlim([0, ntimesteps-1]);
                 ^
5     F821 undefined name 'ntimesteps'
5

E901, E999, F821, F822, F823 are the "showstopper" flake8 issues that can halt the runtime with a SyntaxError, NameError, etc. These five are different from most other flake8 issues, which are merely "style violations" -- useful for readability but they do not affect runtime safety.

  • F821: undefined name name
  • F822: undefined name name in __all__
  • F823: local variable name referenced before assignment
  • E901: SyntaxError or IndentationError
  • E999: SyntaxError -- failed to compile a file into an Abstract Syntax Tree

JSON error in Integrator RNN Tutorial

Hello,

Thank you all for making this, I am very excited about applying these methods.

I just wanted to report a small formatting error in the Integrator RNN Tutorial.ipynb file. This file is missing a single comma between lines 55 and 56, which threw a JSON error for me when I first downloaded it. I mark the missing comma below:

"import sys\n", "import time\n" _,_ "from importlib import reload" ]

Best wishes and thanks again,
Josh

Error in Integrator RNN Tutorial - Dimension mismatch in `affine`

Hi there,
Thanks very much for putting these tutorials together. Unfortunately, I can't seem to get through the Integrator RNN Tutorial (prerequisite to LFADS tutorial).

Here is the error:
dot_general requires lhs batch dimensions and rhs batch dimensions to have the same shape, got [128] and [25].

The error occurs in the "Train the VRNN" section in the following command: opt_state = rnn.update_w_gc_jit(batch, opt_state, opt_update, get_params, inputs, targets, max_grad_norm, l2reg).

Digging deeper in PyCharm's debugger, the lowest-level call from this repo (before it gets passed off to jax) that causes this error is the jax dot-product on this line:
np.dot(params['wO'], x)

Where params is
{'bO': Traced<ShapedArray(float32[1])>with<JVPTrace(level=1/1)>, 'bR': Traced<ShapedArray(float32[100])>with<JVPTrace(level=1/1)>, 'h0': Traced<ShapedArray(float32[100])>with<JVPTrace(level=1/1)>, 'wI': Traced<ShapedArray(float32[100,1])>with<JVPTrace(level=1/1)>, 'wO': Traced<ShapedArray(float32[1,100])>with<JVPTrace(level=1/1)>, 'wR': Traced<ShapedArray(float32[100,100])>with<JVPTrace(level=1/1)>}

And x is Traced<ShapedArray(float32[100])>with<BatchTrace(level=5/1)>

From these arguments, I have no idea where array shapes 128 and 25 are coming from. I'm new to JAX so maybe this is easy to explain, but it's pretty confusing to me.
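For reference, the same class of error can be reproduced in isolation with lax.dot_general when the two operands disagree on the sizes of their batch dimensions. This snippet is independent of the tutorial's internals and does not explain where 128 and 25 come from there; the shapes are chosen only to mirror the message.

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((128, 3, 4))   # batch dimension of size 128
rhs = jnp.ones((25, 4, 5))    # batch dimension of size 25
# Contract axis 2 of lhs with axis 1 of rhs, treating axis 0 of both as batch.
lax.dot_general(lhs, rhs, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
# -> raises an error like: dot_general requires lhs batch dimensions and rhs
#    batch dimensions to have the same shape, got [128] and [25].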

I get the same error whether running on Google Colab Py2, Py3, or on my local Ubuntu 18.04 machine with Python 3.6, jax 0.1.36, and jaxlib 0.1.18.


Side Note: To get this notebook to run on Google Colab, enable the GPU in Notebook Settings, then run the following in a cell at the top:

import os
import sys
if (sys.version_info > (3, 0)):
    from importlib import reload
    !pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-latest-cp36-none-linux_x86_64.whl
else:
    !pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-latest-cp27-none-linux_x86_64.whl
!pip install --upgrade -q jax
if not os.path.isdir('computation-thru-dynamics'):
    !git clone --recursive https://github.com/google-research/computation-thru-dynamics.git
sys.path.append(os.path.join('.', 'computation-thru-dynamics'))

LFADS Tutorial running error

Hi there,

Thanks for this amazing tutorial. I was trying to run through it but I ran into the error below:


ValueError Traceback (most recent call last)
in ()
1 key = random.PRNGKey(onp.random.randint(0, utils.MAX_SEED_INT))
----> 2 trained_params, opt_details,_ = optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps,train_data, eval_data)
3 # trained_params, opt_details, _ =
4 # optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps,
5 # train_data_fun, eval_data_fun)

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in optimize_lfads(key, init_params, lfads_hps, lfads_opt_hps, train_data, eval_data)
159 print_every, update_w_gc, kl_warmup_fun,
160 opt_state, lfads_hps, lfads_opt_hps,
--> 161 train_data)
162 batch_time = time.time() - start_time
163

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in f_jitted(*args, **kwargs)
148 _check_args(args_flat)
149 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 150 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
151 return tree_unflatten(out_tree(), out)
152

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/core.pyc in call_bind(primitive, f, *args, **params)
596 if top_trace is None:
597 with new_sublevel():
--> 598 outs = primitive.impl(f, *args, **params)
599 else:
600 tracers = map(top_trace.full_raise, args)

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/xla.pyc in _xla_call_impl(fun, *args, **params)
440 device = params['device']
441 backend = params['backend']
--> 442 compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
443 try:
444 return compiled_fun(*args)

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in memoized_fun(fun, *args)
208 fun.populate_stores(stores)
209 else:
--> 210 ans = call(fun, *args)
211 cache[key] = (ans, fun.stores)
212 return ans

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/xla.pyc in _xla_callable(fun, device, backend, *arg_specs)
457 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
458 with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459 jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
460 assert not env # no subtraces here
461 del master, env

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in optimize_lfads_core(key, batch_idx_start, num_batches, update_fun, kl_warmup_fun, opt_state, lfads_hps, lfads_opt_hps, train_data)
97 lower = batch_idx_start
98 upper = batch_idx_start + num_batches
---> 99 return lax.fori_loop(lower, upper, run_update, opt_state)
100
101

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in fori_loop(lower, upper, body_fun, init_val)
141 raise TypeError(msg.format(lower_dtype.name, upper_dtype.name))
142 _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun),
--> 143 (lower, upper, init_val))
144 return result
145

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in while_loop(cond_fun, body_fun, init_val)
192 init_avals = tuple(_map(_abstractify, init_vals))
193 cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
--> 194 body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
195 if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
196 msg = "cond_fun must return a boolean scalar, but got pytree {}."

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in _initial_style_jaxpr(fun, in_tree, in_avals)
59 fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
60 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
---> 61 stage_out_calls=True)
62 out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
63 const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/partial_eval.pyc in trace_to_jaxpr(fun, pvals, instantiate, stage_out_calls)
331 with new_master(trace_type) as master:
332 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 333 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
334 assert not env
335 del master

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/lax/lax_control_flow.pyc in while_body_fun(loop_carry)
93 def while_body_fun(loop_carry):
94 i, upper, x = loop_carry
---> 95 return lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)
96 return while_body_fun
97

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in run_update(batch_idx, opt_state)
92 x_bxt = train_data[didxs].astype(np.float32)
93 opt_state = update_fun(batch_idx, opt_state, lfads_hps, lfads_opt_hps,
---> 94 next(fkeyg), x_bxt, kl_warmup)
95 return opt_state
96

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/optimize.pyc in update_w_gc(i, opt_state, lfads_hps, lfads_opt_hps, key, x_bxt, kl_warmup)
140 grads = grad(lfads.lfads_training_loss)(params, lfads_hps, key, x_bxt,
141 kl_warmup,
--> 142 lfads_opt_hps['keep_rate'])
143 clipped_grads = optimizers.clip_grads(grads, lfads_opt_hps['max_grad_norm'])
144 return opt_update(i, clipped_grads, opt_state)

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in grad_f(*args, **kwargs)
353 @wraps(fun, docstr=docstr, argnums=argnums)
354 def grad_f(*args, **kwargs):
--> 355 _, g = value_and_grad_f(*args, **kwargs)
356 return g
357

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in value_and_grad_f(*args, **kwargs)
408 f_partial, dyn_args = _argnums_partial(f, argnums, args)
409 if not has_aux:
--> 410 ans, vjp_py = vjp(f_partial, *dyn_args)
411 else:
412 ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in vjp(fun, *primals, **kwargs)
1270 if not has_aux:
1271 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1272 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1273 out_tree = out_tree()
1274 else:

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/ad.pyc in vjp(traceable, primals, has_aux)
106 def vjp(traceable, primals, has_aux=False):
107 if not has_aux:
--> 108 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
109 else:
110 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/ad.pyc in linearize(traceable, *primals, **kwargs)
95 _, in_tree = tree_flatten(((primals, primals), {}))
96 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 97 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
98 pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
99 aval_primals, const_primals = unzip2(pval_primals)

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/partial_eval.pyc in trace_to_jaxpr(fun, pvals, instantiate, stage_out_calls)
331 with new_master(trace_type) as master:
332 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 333 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
334 assert not env
335 del master

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads_training_loss(params, lfads_hps, key, x_bxt, kl_scale, keep_rate)
540 return the total loss for optimization
541 """
--> 542 losses = lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate)
543 return losses['total']
544

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads_losses(params, lfads_hps, key, x_bxt, kl_scale, keep_rate)
488 key, skeys = utils.keygen(key, 2)
489 keys_b = random.split(next(skeys), B)
--> 490 lfads = batch_lfads(params, lfads_hps, keys_b, x_bxt, keep_rate)
491
492 # Sum over time and state dims, average over batch.

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/api.pyc in batched_fun(*args)
693 _check_axis_sizes(in_tree, args_flat, in_axes_flat)
694 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 695 lambda: _flatten_axes(out_tree(), out_axes))
696 return tree_unflatten(out_tree(), out_flat)
697

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/batching.pyc in batch(fun, in_vals, in_dims, out_dim_dests)
41 def batch(fun, in_vals, in_dims, out_dim_dests):
42 size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
---> 43 out_vals, out_dims = batch_fun(fun, in_vals, in_dims)
44 return map(partial(matchaxis, size), out_dims, out_dim_dests(), out_vals)
45

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/interpreters/batching.pyc in batch_fun(fun, in_vals, in_dims)
47 with new_master(BatchTrace) as master:
48 fun, out_dims = batch_subtrace(fun, master, in_dims)
---> 49 out_vals = fun.call_wrapped(*in_vals)
50 del master
51 return out_vals, out_dims()

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/linear_util.pyc in call_wrapped(self, *args, **kwargs)
151 gen = None
152
--> 153 ans = self.f(*args, **dict(self.params, **kwargs))
154 del args
155 while stack:

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads(params, lfads_hps, key, x_t, keep_rate)
447
448 ic_mean, ic_logvar, xenc_t =
--> 449 lfads_encode(params, lfads_hps, next(skeys), x_t, keep_rate)
450
451 c_t, gen_t, factor_t, ii_t, ii_mean_t, ii_logvar_t, lograte_t = \

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in lfads_encode(params, lfads_hps, key, x_t, keep_rate)
320 x_t = run_dropout(x_t, next(skeys), keep_rate)
321 con_ins_t, gen_pre_ics = run_bidirectional_rnn(params['ic_enc'], gru, gru,
--> 322 x_t)
323 # Push through to posterior mean and variance for initial conditions.
324 xenc_t = dropout(con_ins_t, next(skeys), keep_rate)

/Users/mohsen/Documents/Ylab/Projects/Repos/computation-thru-dynamics/lfads_tutorial/lfads.pyc in run_bidirectional_rnn(params, fwd_rnn, bwd_rnn, x_t)
254 params['bwd_rnn']['h0']))
255 full_enc = np.concatenate([fwd_enc_t, bwd_enc_t], axis=1)
--> 256 enc_ends = np.concatenate([bwd_enc_t[0], fwd_enc_t[-1]], axis=1)
257 return full_enc, enc_ends
258

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc in concatenate(arrays, axis)
1637 if ndim(arrays[0]) == 0:
1638 raise ValueError("Zero-dimensional arrays cannot be concatenated.")
-> 1639 axis = _canonicalize_axis(axis, ndim(arrays[0]))
1640 arrays = _promote_dtypes(*arrays)
1641 # lax.concatenate can be slow to compile for wide concatenations, so form a

/Users/mohsen/anaconda3/envs/jenv/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc in _canonicalize_axis(axis, num_dims)
378 raise ValueError(
379 "axis {} is out of bounds for array of dimension {}".format(
--> 380 axis, num_dims))
381 return axis
382

ValueError: axis 1 is out of bounds for array of dimension 1

Could you please help me rectify this error?

Thanks
