import zodiax
class Foo(zodiax.ExtendedBase):
bar : np.ndarray
def __init__(self, bar):
self.bar = np.array(bar)
foo = Foo([1., 2, 3])
optim, opt_state = foo.get_optimiser('bar', optax.adam(1.))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[119], line 10
8 foo = Foo([1., 2, 3])
9 # optim, opt_state = foo.get_optimiser(['bar'], [optax.adam(2e-8)])
---> 10 optim, opt_state = foo.get_optimiser('bar', optax.adam(2e-8))
File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/zodiax-0.1.1-py3.11.egg/zodiax/base.py:691, in ExtendedBase.get_optimiser(self, paths, optimisers, get_args, pmap)
688 optim = multi_transform(opt_dict, param_spec)
690 # Get filtered optimiser
--> 691 opt_state = optim.init(eqx_filter(self, is_array))
693 return (optim, opt_state) if not get_args \
694 else (optim, opt_state, self.get_args(paths, pmap))
File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/combine.py:135, in multi_transform.<locals>.init_fn(params)
130 if not label_set.issubset(transforms.keys()):
131 raise ValueError('Some parameters have no corresponding transformation.\n'
132 f'Parameter labels: {list(sorted(label_set))} \n'
133 f'Transforms keys: {list(sorted(transforms.keys()))} \n')
--> 135 inner_states = {
136 group: wrappers.masked(tx, make_mask(labels, group)).init(params)
137 for group, tx in transforms.items()
138 }
139 return MultiTransformState(inner_states)
File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/combine.py:136, in <dictcomp>(.0)
130 if not label_set.issubset(transforms.keys()):
131 raise ValueError('Some parameters have no corresponding transformation.\n'
132 f'Parameter labels: {list(sorted(label_set))} \n'
133 f'Transforms keys: {list(sorted(transforms.keys()))} \n')
135 inner_states = {
--> 136 group: wrappers.masked(tx, make_mask(labels, group)).init(params)
137 for group, tx in transforms.items()
138 }
139 return MultiTransformState(inner_states)
File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/wrappers.py:482, in masked.<locals>.init_fn(params)
480 def init_fn(params):
481 mask_tree = mask(params) if callable(mask) else mask
--> 482 masked_params = mask_pytree(params, mask_tree)
483 return MaskedState(inner_state=inner.init(masked_params))
File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/wrappers.py:478, in masked.<locals>.mask_pytree(pytree, mask_tree)
477 def mask_pytree(pytree, mask_tree):
--> 478 return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree)
File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/jax/_src/tree_util.py:206, in tree_map(f, tree, is_leaf, *rest)
173 """Maps a multi-input function over pytree args to produce a new pytree.
174
175 Args:
(...)
203 [[5, 7, 9], [6, 1, 2]]
204 """
205 leaves, treedef = tree_flatten(tree, is_leaf)
--> 206 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
207 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/jax/_src/tree_util.py:206, in <listcomp>(.0)
173 """Maps a multi-input function over pytree args to produce a new pytree.
174
175 Args:
(...)
203 [[5, 7, 9], [6, 1, 2]]
204 """
205 leaves, treedef = tree_flatten(tree, is_leaf)
--> 206 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
207 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
ValueError: Expected list, got Array([1., 2., 3.], dtype=float64).