google-deepmind / dm-haiku
JAX-based neural network library
Home Page: https://dm-haiku.readthedocs.io
License: Apache License 2.0
Hi there,
I'd like to ask for some advice.
Let's say I've got a function approximator for a q-function that uses hk.{get,set}_state() along with hk.transform_with_state(). This means that my function approximator consists of a triplet func, params, state.
I would like to keep a separate copy of this function approximator, i.e. a target network. This means that I keep a separate copy of the params.
Now my question is, would you recommend I also keep a separate copy of the state? And if so, how do we ensure that the usual smooth updates make sense? (E.g. a variance typically can't be updated this way unless you map it to the 2nd moment first.)
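To make the question concrete, here is a minimal sketch of what such a smooth update could look like in plain JAX. It is an illustration under assumptions, not a recommendation from the Haiku authors: the helper names soft_update and soft_update_moments are hypothetical, and the state is assumed to be a dict with 'mean' and 'var' entries, which is not how Haiku's own modules lay out their state.

```python
import jax
import jax.numpy as jnp

def soft_update(target, online, tau):
    """Polyak-average every leaf: target <- (1 - tau) * target + tau * online."""
    return jax.tree_util.tree_map(
        lambda t, o: (1.0 - tau) * t + tau * o, target, online)

def soft_update_moments(target, online, tau):
    """Blend a hypothetical {'mean', 'var'} pair through the raw second moment."""
    second = lambda s: s["var"] + s["mean"] ** 2  # E[x^2] = Var[x] + E[x]^2
    mean = (1.0 - tau) * target["mean"] + tau * online["mean"]
    m2 = (1.0 - tau) * second(target) + tau * second(online)
    return {"mean": mean, "var": m2 - mean ** 2}

# Parameters can be averaged leaf by leaf.
target_params = {"w": jnp.zeros(3)}
online_params = {"w": jnp.ones(3)}
new_target_params = soft_update(target_params, online_params, tau=0.1)

# Statistics go through the second moment so the result is a valid variance.
target_state = {"mean": jnp.array(0.0), "var": jnp.array(1.0)}
online_state = {"mean": jnp.array(2.0), "var": jnp.array(1.0)}
new_target_state = soft_update_moments(target_state, online_state, tau=0.5)
```

Averaging the variances directly would give 1.0 here, whereas going through the second moment also accounts for the gap between the two means.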
Hey, I created this framework called Elegy with a couple of friends. It's a Keras-like framework based on Haiku. The experience of creating a higher-level API for training using Haiku has been interesting: we've had to wrap some of Haiku's modules and functions to enable additional features, but the overall experience has been very positive! It would be great if some of these features eventually went back to Haiku in the future, if you are interested.
https://poets-ai.github.io/elegy/
Feel free to close the issue.
import haiku as hk
class MyModule(hk.Module):
    def __init__(self):
        super().__init__(name="MyModule")
        print("MyModule.__init__ ran")
    def __call__(self, x):
        print("MyModule.__call__ ran")
        return x + 1
def forward(x):
    return MyModule()(x)
rng = hk.PRNGSequence(420)
fwd = hk.transform(forward)
params = fwd.init(next(rng), 1)
x = 0
for _ in range(10):
    x = fwd.apply(params, x)
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
Is this intended? It eventually caused my module to break, because it pulled different hyperparameters on different invocations.
Do you want to hand tune your models? Most folks don’t because it’s slow. So we get into hyperparameter optimization.
In Haiku, we define a model like
def forward(x):
    return MyModule(hyperparameters)(x)
model = hk.transform(forward)
Here hyperparameters is a dict or class that records the decisions modules make when they “pull” choices like “number of layers”, “latent code dimensionality”, “block 2 n_heads”, and so on. We cannot enumerate all possible hyperparameters ahead of time, because some depend on the values of earlier ones and there are just too many combinations.
Those hyperparameters become “stuck” inside the forward function, unless we make hyperparameters a mutable data structure and mutate it inside all our MyModule.__init__ methods. Then the forward function has deeply nested side effects when defined, so you cannot share the decisions across a group of agents because they will mutate shared state. The Haiku tracer also needs to know to ignore these __init__ hyperparameter mutations, which I'm assuming it does, but I'm not 100% sure. Mutating hyperparameters at definition time is also confusing for engineers, because there are many “init” steps at different times.
How can we make an elegant way for “forward” methods to pull hyperparameters and return their decisions?
Right now colab uses jax 0.1.69 and so it fails to import haiku.
I believe it would be better to add both jax and jaxlib as normal requirements, so when someone installs they get the correct versions of the dependencies.
Hi there! I've been trying to get familiar with the library by running some examples in the examples/ folder. My environment was set up according to the instructions on https://github.com/google/jax#installation and https://github.com/deepmind/dm-haiku#installation.
On running the mnist.py example with TensorFlow 2.1.0, a Fatal Python error: Aborted occurs. The full error message is as below:
2020-05-11 20:08:28.772679: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772772: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772790: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
I0511 20:08:29.945827 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:29.947907 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:29.948163 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:29.948284 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.869276 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.870402 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.870607 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.870710 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.915239 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.916428 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.916637 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.916740 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split test, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
2020-05-11 20:08:32.182307: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:236] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2020-05-11 20:08:32.182349: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms)
Fatal Python error: Aborted
Current thread 0x00007f712fd8f740 (most recent call first):
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 156 in compile
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 576 in Compile
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 197 in xla_primitive_callable
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 166 in apply_primitive
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/core.py", line 199 in bind
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 626 in dot_general
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 564 in dot
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 2484 in dot
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 161 in __call__
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 120 in __call__
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
File "mnist.py", line 41 in net_fn
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 271 in init_fn
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 106 in init_fn
File "mnist.py", line 112 in main
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 250 in _run_main
File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 299 in run
File "mnist.py", line 131 in <module>
Aborted (core dumped)
One workaround I've found, which is commonplace when using TensorFlow, is to insert the following code:
from tensorflow.compat.v1 import ConfigProto, InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.7
sess = InteractiveSession(config=config)
However, this kind of defeats the purpose if one is simply trying to use JAX/NumPy instead of TensorFlow. Not sure what else I can provide to help, please do let me know!
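A possible alternative that avoids configuring a TensorFlow session is sketched below. It rests on an assumption that the log does not prove: that the cuBLAS failure comes from TensorFlow and JAX competing for GPU memory. The environment variables are the ones JAX documents for GPU memory allocation, and the commented-out TensorFlow call hides the GPU from TensorFlow when it is only used for data loading.

```python
import os

# Must be set before JAX initializes its GPU backend (i.e. before the first
# computation runs). Disables JAX's up-front GPU memory preallocation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternatively, keep preallocation but cap the fraction of GPU memory used:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

# If TensorFlow is only used for the input pipeline, it can be kept off the
# GPU entirely (call this before TensorFlow touches any device):
# import tensorflow as tf
# tf.config.set_visible_devices([], "GPU")
```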
Is there a plan to add a probabilistic programming interface, such as rlax.distributions or tfp?
Right now, numpyro is the only JAX-powered library for this. I could add an example of how to combine dm-haiku and numpyro if you want.
Hey- This is Chris.
I'm using this open-source for my project.
https://github.com/chris-chris/haiku-scalable-example
Since I'm new to JAX and haiku, I have some questions.
Is there a good way to save/load & compress/decompress & serialize model weights?
I think serialization is an important issue on scalability. Can you give me some keywords or hints about this issue?
Thanks!
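As a starting point, parameters are just a nested structure of arrays, so one simple approach is to flatten them and store the leaves with NumPy. This is a minimal sketch, not an official Haiku checkpointing API: save_params and load_params are hypothetical helper names, and loading assumes you can rebuild a structure of the same shape (for example by calling init again) to serve as a template.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

def save_params(params, file):
    """Flatten the pytree and write its leaves to a compressed .npz archive."""
    leaves = jax.tree_util.tree_leaves(params)
    np.savez_compressed(file, *[np.asarray(leaf) for leaf in leaves])

def load_params(file, like):
    """Read the leaves back and rebuild them into the structure of `like`."""
    treedef = jax.tree_util.tree_structure(like)
    with np.load(file) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(len(data.files))]
    return jax.tree_util.tree_unflatten(treedef, leaves)

params = {"linear": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}}
buffer = io.BytesIO()  # an in-memory stand-in; a file path works as well
save_params(params, buffer)
buffer.seek(0)
restored = load_params(buffer, like=params)
```

Useful keywords to search for are "pytree", "tree_flatten" and "npz".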
There can be some strange behavior in frozendict due to attr assignment:
import haiku as hk
frozen = hk.data_structures.to_immutable_dict({'a': 'foo'})
print(frozen['a'], frozen.a) # --> foo foo
frozen.a = 'bar' # should raise error
print(frozen['a'], frozen.a) # --> foo bar
My suggestion would be to simply disallow any setattr on frozendict.
Also, as a side note, please expose frozendict to the public api. Right now, the only way to access it is through haiku._src.data_structures.frozendict.
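The suggestion could look roughly like the sketch below. FrozenDict here is a hypothetical stand-alone class written for illustration, not Haiku's actual frozendict: it keeps attribute-style reads but rejects every attribute write.

```python
from collections.abc import Mapping

class FrozenDict(Mapping):
    """Hypothetical read-only mapping; not Haiku's actual frozendict."""

    def __init__(self, *args, **kwargs):
        # Bypass our own __setattr__ for the one legitimate assignment.
        object.__setattr__(self, "_data", dict(*args, **kwargs))

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __getattr__(self, name):
        # Only called when normal lookup fails; mirrors frozen.a == frozen['a'].
        if name == "_data":  # avoid infinite recursion if _data is not set yet
            raise AttributeError(name)
        try:
            return self._data[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        raise AttributeError("FrozenDict is immutable")

    def __delattr__(self, name):
        raise AttributeError("FrozenDict is immutable")

frozen = FrozenDict({"a": "foo"})
try:
    frozen.a = "bar"  # now raises instead of silently shadowing the key
    error = None
except AttributeError as err:
    error = err
```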
We should add the equivalent of this test:
https://github.com/deepmind/dm-haiku/blob/4bcbec9f874c847075ad666c7f3621c8c524af3b/haiku/_src/integration/bfloat16_test.py#L31
...except using float64. This involves enabling --jax_enable_x64 for the test & fixing outstanding errors.
Currently, MaxPool and AvgPool expect window shape and stride to be "Same rank as value or int." I think the layers should check the window shape and stride, ensure the user conforms to this, and throw an error (or perhaps a warning) if they do not.
Coming from other frameworks, users would be used to passing in window shapes/strides with the same rank as the number of spatial dimensions. For example, for a 2D maxpool, users would be used to passing in window_shape=(2, 2), which would resolve to (1, 1, 2, 2) in Haiku, which is likely not what the user intended. Throwing an error would force the user to pass in window_shape=(1, 2, 2, 1) or window_shape=2, and would save them debugging time.
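The proposed check could be as small as the sketch below. check_window is a hypothetical helper, not part of Haiku, and the int expansion assumes a batch-first, channels-last layout, consistent with window_shape=2 being equivalent to (1, 2, 2, 1) for a rank-4 input as described above.

```python
def check_window(window, value_rank, name="window_shape"):
    """Return a full-rank window tuple, or raise if the rank is ambiguous."""
    if isinstance(window, int):
        # Assumed layout: batch first, channels last, i.e. (N, ..., C).
        return (1,) + (window,) * (value_rank - 2) + (1,)
    window = tuple(window)
    if len(window) != value_rank:
        raise ValueError(
            f"{name} must be an int or have rank {value_rank}, got {window}")
    return window

full_window = check_window(2, value_rank=4)
try:
    check_window((2, 2), value_rank=4)  # spatial-only shape is rejected
    rank_error = None
except ValueError as err:
    rank_error = err
```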
I was looking at Haiku's dropout implementation and it seems like there's no difference between the use of dropout during training and during test. Normally dropout is only applied during training. Is this intended behavior?
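If the dropout function is stateless and always drops, as the question observes, then the train/test switch has to live in the caller. The sketch below shows that pattern in plain JAX. The dropout function here is written out for illustration and is not Haiku's implementation, and the is_training argument is an assumed convention.

```python
import jax
import jax.numpy as jnp

def dropout(rng, rate, x):
    """Inverted dropout: zero entries with probability `rate`, rescale the rest."""
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def forward(x, rng, is_training):
    # `is_training` is a plain Python bool, so the branch is chosen at trace time.
    if is_training:
        x = dropout(rng, 0.5, x)
    return x

x = jnp.ones((4, 4))
train_out = forward(x, jax.random.PRNGKey(0), is_training=True)
eval_out = forward(x, jax.random.PRNGKey(0), is_training=False)
```

With this layout the evaluation path never touches the random key, so the same function can be used for both modes.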
How could I add a constant state to my haiku module?
Specifically I would want something like this:
class MyModule(hk.Module):
    def __init__(self, output_size, const, name=None):
        super().__init__(name=name)
        if const:
            # hk.conts is the wished-for (hypothetical) API: won't be updated when adding gradient
            self.b = hk.conts(jnp.ones(output_size))
        else:
            # will get updated when adding gradient
            self.b = jnp.zeros(output_size)
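One way to get the "not updated by gradients" behavior without any new API is jax.lax.stop_gradient, sketched below in plain JAX. This is an illustration of the idea rather than the Haiku-recommended answer, and the params layout with "w" and "b" is made up for the example.

```python
import jax
import jax.numpy as jnp

def forward(params, x, const):
    b = params["b"]
    if const:
        # Gradients do not flow into b, so an optimizer step leaves it unchanged.
        b = jax.lax.stop_gradient(b)
    return jnp.sum(params["w"] * x + b)

params = {"w": jnp.ones(3), "b": jnp.ones(3)}
x = jnp.arange(3.0)
grads_const = jax.grad(forward)(params, x, True)
grads_free = jax.grad(forward)(params, x, False)
```

In the const case the gradient for b is all zeros, while the gradient for w is unaffected.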
Currently, there is no easy way to just call an initializer outside of hk.transform. This can be worked around, but it discourages using hk.next_rng_key when writing custom initializers, since whatever code depends on hk.next_rng_key is now unusable without buying in to everything else. This is a minor annoyance when debugging an initializer, since there is no way to retrieve sampled weights without wrappers. I could definitely imagine this becoming more of a pain in more complex settings that involve writing many initializers (admittedly, whether that is a plausible scenario is a different question).
This is definitely not a high priority thing, but if there was a way to allow using hk.with_rng outside of transform, it wouldn't be an issue anymore. If the different parts of the frame stack backend are too interdependent to safely allow hk.with_rng outside of hk.transform, then it's probably not worth the effort, but it never hurts to ask!
After you open-sourced haiku, rlax recently, I started to learn about JAX and I love these projects.
However, whenever I share these open sources to deep learning communities and my friends, they don't seem to be convinced to use them. I believe you must have solid reasons for choosing JAX for the RL research framework, which I've tried to guess. So far, these are the advantages that I could think of.
But I'm still not fully aware of the logic behind it. Would it be possible for you to briefly explain the motivation behind using JAX for your research?
This variation on @tomhennigan's example tries to build a tree of module types.
It assumes the parameter creation order is preserved when flattening the parameter dictionary, which may be incorrect. Alternatively, if the path could be added to context, or if it is possible to recover the path from context, that would support a more satisfying solution. With module names and parameters possibly containing "/", it is not clear to me how to construct the path. What am I missing?
def init_and_build_module_tree(f):
    """
    Decorated functions build a tree of module types alongside the parameters
    Usage:
        def f(x):
            net = haiku.nets.MLP([300, 100, 10])
            return net(x)
        params, modules = init_and_build_module_tree(f)(rng_key, np.zeros(4))
        params = tree.map_structure(transform_params, params, modules)
    """
    def _init_and_build_module_tree(rng_key, *args, **kwargs):
        module_types = []
        def record_module_type(next_creator, shape, dtype, init, context):
            module_types.append(type(context.module))
            return next_creator(shape, dtype, init)
        def with_creator(*aargs, **kkwargs):
            with haiku.experimental.custom_creator(record_module_type):
                return f(*aargs, **kkwargs)
        params, _ = haiku.transform_with_state(with_creator).init(
            rng_key,
            *args,
            **kwargs
        )
        module_tree = tree.unflatten_as(
            params,
            module_types
        )
        return params, module_tree
    return _init_and_build_module_tree
Hey,
I want to create a Haiku compatible library which provides a Keras-like interface called Elegy. I love how easy it is to use Haiku, however, I would like users to code something like this:
module = hk.Sequential([
    hk.Flatten(),
    hk.Linear(300), jax.nn.relu,
    hk.Linear(100), jax.nn.relu,
    hk.Linear(10),
])
model = Model(
    module,
    loss=...,
    metrics=dict(
        accuracy=Accuracy(),  # this is probably an hk.Module as well
    ),
    optimizer=...,
)
Instead of checking if __init__ is called inside hk.transform, wouldn't it make more sense to check if __call__ is called inside of this context? If this is not possible because you are not able to intercept __call__, maybe having the user implement something like call or apply would be better?
In a sparsely gated mixture of experts where each expert has state (a memory), there is a KeyError if different experts are activated on apply than happened to be activated on init. To solve this, you can pass an 'init' flag into your custom module and, if it is True, use all the experts on that call. If that's a memory issue, you can use them one by one. Just make sure init hits all conditional branches of the modules with state.
(Let me know if there's an easier solution.)
Hey, I use pyright / pylance for type checking and they are pretty unhappy that hk.Module doesn't define an abstract __call__ method. I get type errors all over the place when defining code that takes arbitrary hk.Modules. Given most of Haiku is already typed, this would be a nice addition.
We currently don't explain what hk.jit, hk.remat, etc. are and why they exist. It would be good to extend the documentation with these.
We think the following should be a valid construct, since access to ThreadLocalState should be gated in try/excepts - but we don't have many tests for this.
We should!
try:
    hk.transform(possibly_bad).apply(....)
except:
    logging.info('lol bad')
hk.transform(known_good).apply(...)
Thanks for creating such a nice library!
My example use case in flax:
nn = ...  # flax.nn.Module
_, nn_params = nn.init(rng_key, data)
model = flax.nn.Model(nn, nn_params)
With this, I can call model(x) and I can also access the current params via model.params.
For haiku:
nn = ...  # hk.Module
nn_params = nn.init(rng_key, data)
# I can do this
partial_fun = lambda data: nn.apply(nn_params, rng_key, data)
I can do partial_fun(x), but because partial_fun is a lambda, I can't access nn_params from partial_fun. Wondering if there's any workaround to achieve this in haiku.
For more context, I am trying to integrate haiku with numpyro so that we can convert a traditional NN into a Bayesian NN. You can see this issue pyro-ppl/numpyro#705 for more context, or this example notebook.
jax.remat wraps all of its inputs with _foil_cse.
When we do the state-threading in hk.stateful.remat, the threaded-out state now is the output of _foil_cse. Any downstream uses of this state now access the foil-cse'd param/state, rather than the original.
Example:
def f(x, ctxt):
    return jnp.sin(x + ctxt[0]), ctxt
@jax.jit
def g(x):
    ctxt = [x + i for i in range(2)]
    x, ctxt = jax.remat(f)(x, ctxt)
    return jnp.sin(x + ctxt[1])
g(1.).block_until_ready()
This results in HLO that looks like:
HloModule jit_g__1.46, is_scheduled=true
ENTRY jit_g__1.46 {
constant.2 = f32[]{:T(256)} constant(2)
constant = f32[]{:T(256)} constant(0)
constant.5 = f32[]{:T(256)} constant(1)
rng.2 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
compare.2 = pred[]{:T(256)E(32)} compare(rng.2, constant.2), direction=LT
rng.1 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
compare.1 = pred[]{:T(256)E(32)} compare(rng.1, constant.2), direction=LT
rng = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
compare = pred[]{:T(256)E(32)} compare(rng, constant.2), direction=LT
parameter.1 = f32[]{:T(256)} parameter(0), parameter_replication={false}
select = f32[]{:T(256)} select(compare, parameter.1, constant)
select.1 = f32[]{:T(256)} select(compare.1, parameter.1, constant)
add = f32[]{:T(256)} add(select, select.1)
sine = f32[]{:T(256)} sine(add)
add.6 = f32[]{:T(256)} add(parameter.1, constant.5)
select.2 = f32[]{:T(256)} select(compare.2, add.6, constant)
add.43 = f32[]{:T(256)} add(sine, select.2)
sine.44 = f32[]{:T(256)} sine(add.43)
ROOT tuple.45 = (f32[]{:T(256)}) tuple(sine.44)
}
Possible solutions:
[...] stateful_fun, especially during apply.
[...] hk.PRNGSequence inside stateful_fun so that RNG doesn't get threaded in/out.
[...] state is only updated in-place for state that's actually been changed. JAX referential transparency makes this challenging for the case in which Haiku is not jitted but internal functions are via hk.jit.
[...] hk.remat on top of hk._src.lift.
Hello, haiku team! Thanks a lot for making awesome haiku.
I'm interested in sequential probabilistic models. Normally, parameters of probabilistic models are constrained. A simple example would be a variance, which can only be positive. I gave an example and explanation of the constrained parameters in #16 (comment). Pytrees fit the described use case well. The user can create their own differentiable "vectors", and I would expect haiku to support these custom structures out of the box. This would allow a user to get back actual structures from transformed functions for printing, debugging, and plotting purposes (the list can be extended with other examples from academic needs). Unfortunately, custom differentiable structures don't work at the moment.
In [58]: class S(hk.Module):
    ...:     def __init__(self, x, y):
    ...:         super().__init__()
    ...:         # These are parameters:
    ...:         self.x = x
    ...:         self.y = y
    ...:     def __repr__(self):
    ...:         return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)
    ...: def S_flatten(v):
    ...:     children = (v.x, v.y)
    ...:     aux_data = None
    ...:     return (children, aux_data)
    ...: def S_unflatten(aux_data, children):
    ...:     return S(*children)
    ...: register_pytree_node(S, S_flatten, S_unflatten)
    ...:
    ...:
    ...: def function(s):
    ...:     return np.sqrt(s.x**2 * s.y**2)
    ...:
    ...: def loss(x):
    ...:     s = S(1.0, 2.0)
    ...:     a = hk.get_parameter("free_parameter", shape=[], dtype=jnp.float32, init=jnp.zeros)
    ...:     return jnp.sum([function(s) * a * x])
    ...:
    ...: x = jnp.array([2.0])
    ...: forward = hk.transform(loss)
    ...: key = jax.random.PRNGKey(42)
    ...: params = forward.init(key, x)
In [59]: params
Out[59]:
frozendict({
  '~': frozendict({'free_parameter': DeviceArray(0., dtype=float32)}),
})
Thanks
I was wondering if there was a simple way to find the total number of parameters in a haiku model, apart from iterating through the different layers and counting them.
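Since the parameters are a nested structure of arrays, one short way to do this is to flatten the structure and sum the array sizes. The sketch below uses only JAX's tree utilities. count_params is a hypothetical helper name, and the params dict is a hand-written stand-in for what init would return.

```python
import jax
import jax.numpy as jnp

def count_params(params):
    """Sum the number of elements over every array in the parameter pytree."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

# Stand-in for the output of init: one linear layer with 8 inputs, 4 outputs.
params = {"linear": {"w": jnp.zeros((8, 4)), "b": jnp.zeros(4)}}
total = count_params(params)  # 8 * 4 + 4 = 36
```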
Hey, I was wondering if you plan to add this layer?
I am implementing the VGG16 architecture, which is pretty straightforward, but when I use hk.transform to obtain the initial parameters, the call never returns.
My code:
import jax
import jax.numpy as np
import haiku as hk
from PIL import Image
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
def _make_layers():
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers.append(hk.MaxPool(window_shape=[2, 2],
                                     strides=[2, 2],
                                     padding="VALID"))
        else:
            conv2d = hk.Conv2D(v,
                               kernel_shape=[3, 3],
                               stride=1,
                               padding='SAME')
            layers += [conv2d, jax.nn.relu]
            in_channels = v
    return layers
class VGG16(hk.Module):
    def __init__(self, num_classes: int = 1000) -> None:
        super(VGG16, self).__init__()
        features = _make_layers()
        classifier = [
            hk.Flatten(),
            hk.Linear(4096), jax.nn.relu,
            hk.Linear(4096), jax.nn.relu,
            hk.Linear(num_classes), jax.nn.softmax
        ]
        self.vgg = hk.Sequential(features + classifier)
    def __call__(self, x):
        return self.vgg(x)
def _forward(image):
    net = VGG16()
    return net(image)
rng = jax.random.PRNGKey(0)
print('Loading image...')
im = Image.open('images/dog.jpeg').resize((224, 224))
im = np.array(im).astype('float32') / 255.
print('hk.transform...')
vgg_forward_fn = hk.transform(_forward)
print('This will never end..')
params = vgg_forward_fn.init(rng, np.expand_dims(im, 0))
I am on Windows 10 using WSL. I have just implemented the same architecture using the experimental stax library, and the initialization finishes in about 1 second.
Any idea what I am doing wrong?
Thanks in advance
To enable users to easily create per-layer weight and activity regularizers, plus other forms of losses created by intermediate layers, it would be very useful if Haiku had an hk.add_loss utility. When called within a transform, it would append a loss to a list of losses, which the user could later retrieve as an additional output from apply. I guess that this would require an additional flag to hk.transform and friends.
Hi Haiku Team! Thank you for all your work on Haiku.
I'm interested in writing a layer which takes a function as an argument and produces custom gradients. (For context, I'm implementing a method to find the stationary point of a function.)
A toy example of my implementation in Pure JAX is below and the full implementation here:
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def g(x: jnp.ndarray, fun: Callable):
    return jax.lax.stop_gradient(fun(x))
def g_fwd(x, fun):
    return g(x, fun), x
def g_bwd(fun, res, grad):
    x, = res
    return fun(x)
g.defvjp(g_fwd, g_bwd)
My question: What is the best way to implement this in Haiku?
I see some approaches:
Create a JAX function which takes a haiku.Module as an argument (in place of the function). This currently leads to issues with the Module in the backward pass (it doesn't seem to be transformed).
Create a haiku.Module with custom gradients. However this is a pure function, so creating a Module feels wrong as it doesn't require Parameters (but may require the Haiku State of another module?).
Use haiku.to_module(f). In this approach I'd use get_parameters to access the states of the input function and potentially use some naming conventions to make sure I have the correct scope. I imagine this is the best approach (and maybe very similar to the first approach) - but I really can't find much documentation on how naming variables or accessing them really works!
Would it be possible to share any prior art on how to implement any of these approaches?
import haiku as hk
import jax
import jax.numpy as jnp
from functools import partial
from typing import Callable
def build_net(output_size):
    def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
        linear = hk.Linear(output_size, name='l1')
        x = linear(x)
        return g(x, linear)
    return forward_fn
b_size, s_size, h_size = 1, 2, 3
input = jnp.ones((b_size, s_size, h_size))
rng = jax.random.PRNGKey(42)
net = build_net(h_size)
net = hk.transform(net)
params = net.init(rng, input)
def loss_fn(params, rng, x):
    return jnp.sum(net.apply(params, rng, x))
print(jax.grad(loss_fn)(params, rng, input))
Gives:
15 def g_bwd(fun, res, grad):
16 input, = res
---> 17 return fun(input)
18
19 g.defvjp(g_fwd, g_bwd)
/usr/local/lib/python3.6/dist-packages/haiku/_src/module.py in wrapped(self, *args, **kwargs)
299 if not base.frame_stack:
300 raise ValueError(
--> 301 "All `hk.Module`s must be initialized inside an `hk.transform`.")
302
ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.
It appears that the manifest is missing at least one file necessary to build
from the sdist for version 0.0.1b0. You're in good company, about 5% of other
projects updated in the last year are also missing files.
+ /tmp/venv/bin/pip3 wheel --no-binary dm-haiku -w /tmp/ext dm-haiku==0.0.1b0
Looking in indexes: http://10.10.0.139:9191/root/pypi/+simple/
Collecting dm-haiku==0.0.1b0
Downloading http://10.10.0.139:9191/root/pypi/%2Bf/f19/fdaf8281b7fb0/dm-haiku-0.0.1b0.tar.gz (121 kB)
ERROR: Command errored out with exit status 1:
command: /tmp/venv/bin/python3 -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py'"'"'; __file__='"'"'/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' egg_info --egg-base /tmp/pip-wheel-8pyg10w7/dm-haiku/pip-egg-info
cwd: /tmp/pip-wheel-8pyg10w7/dm-haiku/
Complete output (7 lines):
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py", line 56, in <module>
install_requires=_parse_requirements('requirements.txt'),
File "/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py", line 33, in _parse_requirements
with open(requirements_txt_path) as fp:
FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.
I left an issue on the JAX page here about omni-staging breaking LSTMs. It was closed since @tomhennigan mentioned it was fixed in a recent PR, but it seems like it is not. In summary, JIT compiling the gradient of the loss with respect to an LSTM with static unrolling never finishes compiling. It keeps using more and more RAM until it crashes.
Here is a gist reproducing the issue (with Haiku and JAX at head). The colab can't be run sequentially: you must run the top portion then one of "With OmniStaging" or "Without OmniStaging" (and restart to run the other one).
The code works when I use dynamic_unroll instead of static_unroll, but this is just a temporary workaround as I would like to use static_unroll.
This is as much a question as it is a feature request. What is the reasoning for not allowing a module instance to be created (but not used) outside hk.transform? I took a look at hk.Module and ModuleMetaClass, but I feared my soul would get harvested by the dark forbidden magic involved before I could identify all the API features it permits.
For example, I would have expected this to be possible:
linear = hk.Linear(10)  # currently not allowed
def forward(x):
    return linear(x)
model = hk.transform(forward)
Concretely, I'm curious to know what would have to be sacrificed (if anything) to support this kind of usage? Is it meant to prevent a module instance from being used in two different functions wrapped by two different hk.transform calls?
I wouldn't be surprised if I were missing some nasty side effect if you were to allow module creation outside of hk.transform, but, if not, I think it would be more intuitive to allow this kind of usage.
I've noticed _src has become quite large. I think eventually splitting it up into folders makes more sense. We could have:
nn
initializers
regularizers
losses
metrics
Hi Haiku team,
Thanks for opensourcing such a great library!
It looks like hk.ResetCore requires the leading dimension of the inputs to be the batch dimension: https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py#L638. However, other RNN cores, like LSTM (https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py#L264), do not have such a requirement. In fact, I found this mismatch when I was using an RNN without a batch dimension in the inputs.
Could you perhaps change hk.ResetCore so that it also works with inputs without a batch dimension?
Thanks!
Zeyu
Using the current version of master 66f9c69 of Haiku, I am getting the following error on Colab
AttributeError Traceback (most recent call last)
<ipython-input-3-3a9e6adbfff5> in <module>()
----> 1 import haiku as hk
/usr/local/lib/python3.6/dist-packages/haiku/__init__.py in <module>()
17
18 from haiku import data_structures
---> 19 from haiku import experimental
20 from haiku import initializers
21 from haiku import nets
/usr/local/lib/python3.6/dist-packages/haiku/experimental.py in <module>()
22 from haiku._src.base import custom_getter
23 from haiku._src.base import ParamContext
---> 24 from haiku._src.dot import to_dot
25 from haiku._src.lift import lift
26 from haiku._src.module import profiler_name_scopes
/usr/local/lib/python3.6/dist-packages/haiku/_src/dot.py in <module>()
23
24 from haiku._src import data_structures
---> 25 from haiku._src import module
26 from haiku._src import utils
27 import jax
/usr/local/lib/python3.6/dist-packages/haiku/_src/module.py in <module>()
26 from haiku._src import base
27 from haiku._src import data_structures
---> 28 from haiku._src import named_call
29 from haiku._src import utils
30 import jax.numpy as jnp
/usr/local/lib/python3.6/dist-packages/haiku/_src/named_call.py in <module>()
29
30 # Registering named call as a primitive
---> 31 named_call_p = core.CallPrimitive('named_call')
32 # named_call is implemented as a plain core.call and only diverges
33 # under compilation (see named_call_translation_rule)
AttributeError: module 'jax.core' has no attribute 'CallPrimitive'
I believe that's because Haiku now requires jax>=0.1.71, while the version by default on Colab is jax==0.1.69. CallPrimitive was introduced in jax 0.1.71.
https://github.com/google/jax/blob/1545a29e6d69a7b3c7fdf9a49b38004759a9fbfa/jax/core.py#L1106-L1115
To reproduce (inside a Colab):
import jax
print(jax.__version__) # 0.1.69
!pip install -q git+https://github.com/deepmind/dm-haiku
import haiku as hk
Run !pip install -q --upgrade jax jaxlib first in your Colab to fix this issue.
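A quick way to spot this mismatch before importing Haiku is to compare the installed JAX version against the minimum. The sketch below is plain Python; `version_tuple`, `MIN_JAX`, and `jax_is_new_enough` are illustrative names, the 0.1.71 floor is simply the one quoted in this report, and the parser assumes purely numeric version components.

```python
def version_tuple(version):
    """Turn '0.1.69' into (0, 1, 69) so versions compare numerically."""
    return tuple(int(part) for part in version.split(".")[:3])

MIN_JAX = (0, 1, 71)  # assumption: the minimum quoted in this report

def jax_is_new_enough(installed):
    """True if the installed version string meets the assumed minimum."""
    return version_tuple(installed) >= MIN_JAX

print(jax_is_new_enough("0.1.69"))  # False
print(jax_is_new_enough("0.1.71"))  # True
```

In a Colab cell you would pass `jax.__version__` to `jax_is_new_enough` and upgrade if it returns False.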
import os
os.environ["JAX_ENABLE_X64"] = "1"
import jax
import haiku as hk
import numpy as np
@hk.transform
def f(x):
  return hk.Linear(4)(x)
f32_data = np.zeros((4, 8), dtype=np.float32)
p = f.init(jax.random.PRNGKey(428), f32_data)
print(jax.tree_map(lambda t: t.dtype, p))
f32_params = jax.tree_map(lambda t: t.astype(np.float32), p)
print(f.apply(f32_params, f32_data).dtype)
Prints:
frozendict({
'linear': frozendict({'b': dtype('float32'), 'w': dtype('float64')}),
})
dtype('float32')
Hopefully the bfloat16 compatibility work means we get most of this for free and that we only need to port the initializers.
Let's say I want to iterate through all modules inside an hk model and replace all hk.Linear modules with my own custom Module, or monkey-patch some of their properties. Does Haiku currently support something along these lines?
I noticed earlier today that Haiku has SpectralNormalization -- very cool!
I'm interested in implementing an improved version, which does a much better job estimating the norm for convolutional layers and should converge to the correct answer for any linear operator. The trick is to use auto-diff to calculate the transpose of the linear operator. In contrast, the current implementation is only accurate for dense matrices.
Here's my implementation in pure JAX: https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc
My question: how can I implement this in Haiku?
jax.vjp on Module? I'm guessing (though to be honest I haven't checked yet) that normal JAX functions would break, given the way that Haiku adds mutable state.
Hi:
I'm computing the Jacobian of my model on every step. The input is a z-dimensional vector and the model's output is (bz,ch,h,w); it's a decoder. I expected the output of the Jacobian to be (bz,ch,h,w,z), and it is when I compute the Jacobian of a single row in my batch. However, when I pass in the whole batch, the dimensions become (bz,ch,h,w,bz,z). Why the extra bz?
I checked the Jacobian output values for the i-th row: (i,ch,h,w,j,z) is zero for all j != i, and it has the correct values when j = i. I can still use this; however, I have to take the extra step of removing the zero outputs.
Here's my code for the Jacobian:
decoder_jac = hk.transform(lambda z: jacfwd(Decoder())(z))
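The extra bz axis is what `jacfwd` is defined to return: it differentiates every output element with respect to every input element, so a (bz, z) input appends (bz, z) to the output shape. The cross-example blocks are zero as long as the rows are processed independently. One way to get (bz, ch, h, w, z) directly is to take the Jacobian of the per-example function and `vmap` it over the batch. The sketch below uses a toy pure-JAX `decoder` as a stand-in for the Haiku module; its name, weights, and shapes are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

BZ, Z, CH, H, W = 3, 5, 2, 4, 4
weights = jax.random.normal(jax.random.PRNGKey(0), (Z, CH * H * W))

def decoder(z):
    """Toy decoder: maps (..., Z) to (..., CH, H, W), each row independently."""
    out = jnp.tanh(z @ weights)
    return out.reshape(z.shape[:-1] + (CH, H, W))

z_batch = jax.random.normal(jax.random.PRNGKey(1), (BZ, Z))

# Jacobian of the batched function: output shape + input shape.
full = jax.jacfwd(decoder)(z_batch)
# Jacobian of the single-example function, mapped over the batch axis.
per_example = jax.vmap(jax.jacfwd(decoder))(z_batch)

print(full.shape)         # (3, 2, 4, 4, 3, 5)
print(per_example.shape)  # (3, 2, 4, 4, 5)
```

With a transformed Haiku model, the same pattern would presumably wrap the `apply` function (with the parameters closed over) rather than the module itself, provided the model accepts a single unbatched z.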
Hi! I'm one of the developers of Haiku and a board member of Haiku, Inc.
We hold a trademark on the name Haiku for our open-source operating system, as well as a registered logo mark on our Haiku logo.
Haiku has been in development for nearly 20 years now, and usage of the Haiku name for other open source software projects could create confusion.
Please feel free to reach out to Haiku, Inc. if you have any questions.
I'm trying to figure out the differences between pytorch's LSTM and Haiku's LSTM. In Haiku, the LSTM expects the input to be a rank 1 or rank 2 tensor. However, in pytorch, the input is expected to be a rank 3 tensor. I was hoping to get some clarification on how one would expect to get the same behavior in Haiku's lstms compared to torch's lstms.
Some additional information -- if this is the case, it seems like the LSTM in Haiku is significantly slower than torch. It's very likely that I've made a mistake somewhere here.
Also, would it be correct to say that Haiku's static_unroll + lstm is the same as PyTorch's lstm? In which case, isn't Haiku's LSTM technically an LSTMcell?
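For comparison, here is a rough pure-JAX sketch of the cell/unroll split the question is asking about: a single-step cell that takes [B, D] inputs, and an unroll over a time-major [T, B, D] sequence, which is the rank-3 input a sequence-level LSTM consumes. `lstm_cell`, `unroll`, and the parameter layout are illustrative; this is not Haiku's or PyTorch's implementation.

```python
import jax
import jax.numpy as jnp

def lstm_cell(params, x_t, state):
    """One time step: x_t is [B, D]; state is (h, c), each [B, H]."""
    h, c = state
    gates = jnp.concatenate([x_t, h], axis=-1) @ params["w"] + params["b"]
    i, g, f, o = jnp.split(gates, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return h, (h, c)

def unroll(params, xs, state):
    """Apply the cell across a time-major sequence xs of shape [T, B, D]."""
    def step(carry, x_t):
        out, new_carry = lstm_cell(params, x_t, carry)
        return new_carry, out
    final_state, outs = jax.lax.scan(step, state, xs)
    return outs, final_state

T, B, D, H = 5, 2, 3, 4
key = jax.random.PRNGKey(0)
params = {"w": 0.1 * jax.random.normal(key, (D + H, 4 * H)),
          "b": jnp.zeros(4 * H)}
xs = jnp.ones((T, B, D))
state0 = (jnp.zeros((B, H)), jnp.zeros((B, H)))

outs, (h_T, c_T) = unroll(params, xs, state0)
print(outs.shape)  # (5, 2, 4)
```

The cell alone handles the rank-2 case; the unroll is what adds the time axis.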
Hi, I'm trying to do meta learning with some slow and fast weights. From the params returned when calling .init I obtain the slow and fast weights like this:
params = f.init(...)
fast_weights = params["fast_weights"]
And just before calling the .apply function I want to merge them (the fast weights will be modified). My first attempt was to use the hk.data_structures.merge method like this:
params = hk.data_structures.merge(params, fast_weights)
output = f.apply(params, rng, inputs)
But this raises the exception AttributeError: 'DeviceArray' object has no attribute 'items'. I was wondering if this behaviour is intended or if I should use a different approach for what I want to do.
Thanks!
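The error is consistent with `merge` expecting every argument to be a full two-level mapping (module name to parameter name to array), whereas `params["fast_weights"]` is only the inner level. A sketch of one workaround, assuming that layout, is to wrap the modified weights back under their module name before merging. Plain dicts and NumPy arrays stand in for Haiku's structures here, and `merge_two_level` is a hypothetical helper, not a Haiku function.

```python
import numpy as np

def merge_two_level(*structures):
    """Merge {module: {param: array}} mappings; later ones win on conflicts."""
    out = {}
    for structure in structures:
        for module_name, module_params in structure.items():
            out.setdefault(module_name, {}).update(module_params)
    return out

params = {
    "slow_weights": {"w": np.ones((2, 2))},
    "fast_weights": {"w": np.zeros((2, 2))},
}
# A modified copy of the fast weights (inner level only).
fast_weights = {"w": params["fast_weights"]["w"] + 1.0}

# Wrap the inner dict under its module name so both arguments share a layout.
merged = merge_two_level(params, {"fast_weights": fast_weights})
print(merged["fast_weights"]["w"])
```

With Haiku itself the equivalent call would presumably be `hk.data_structures.merge(params, {"fast_weights": fast_weights})`.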
module.params_dict() can behave in surprising ways:
def f(x):
  mod = hk.Linear(8)
  print(mod.params_dict())  # empty during init, full during apply
  sequential = hk.Sequential([mod])
  print(sequential.params_dict())  # always empty
  out = sequential(x)
  print(sequential.params_dict())  # no longer empty
  return out
net = hk.transform(f)
p = net.init(jax.random.PRNGKey(428), np.zeros((2, 3)))
net.apply(p, np.zeros((2, 3)))
Prints:
{}
{}
{...}
{...}
{}
{...}
We should clean up & clearly define the desired semantics of params_dict().
The JAX team is going to remove optix from the library now that optax exists. Can I open a PR changing all imports from jax.experimental.optix to optax?
Seems like unintended behavior.
/usr/local/lib/python3.6/dist-packages/haiku/_src/embed.py in __init__(self, vocab_size, embed_dim, embedding_matrix, w_init, lookup_style, name)
73 """
74 super(Embed, self).__init__(name=name)
---> 75 if not embedding_matrix and not (vocab_size and embed_dim):
76 raise ValueError(
77 "hk.Embed must be supplied either with an initial `embedding_matrix` "
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
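The ambiguity comes from applying `not` to a multi-element array. Below is a minimal NumPy sketch of the failure and of the usual fix, an explicit `is None` check. `check_args` is a hypothetical stand-in for the constructor's validation, not the library's code.

```python
import numpy as np

def check_args(vocab_size=None, embed_dim=None, embedding_matrix=None):
    """Return True when enough information is given to build the embedding."""
    has_matrix = embedding_matrix is not None  # never calls bool() on the array
    has_dims = vocab_size is not None and embed_dim is not None
    return has_matrix or has_dims

matrix = np.zeros((10, 4))

try:
    not matrix  # what the condition in the traceback effectively evaluates
except ValueError as err:
    print("ValueError:", err)

print(check_args(embedding_matrix=matrix))  # True
print(check_args())                         # False
```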
Right now it seems one needs bazel to install haiku (even using pip)
Use collections.abc instead. Also, collections.abc.KeysView is used in data_structures.py below, but the import statement at the top is just collections. Accessing collections.abc with just import collections is not supported.
haiku/_src/utils.py
121: not isinstance(element, collections.Sequence)):
haiku/_src/data_structures.py
80:class KeysOnlyKeysView(collections.abc.KeysView):
haiku/_src/layer_norm.py
79: elif (isinstance(axis, collections.Iterable) and
haiku/_src/stateful.py
48: if isinstance(v, collections.Mapping):
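A short sketch of the suggested pattern: import the `collections.abc` submodule explicitly and use the abstract base classes from there. The helper `is_non_string_sequence` is illustrative and is not taken from the Haiku sources.

```python
import collections.abc  # explicit submodule import, not just `import collections`

def is_non_string_sequence(element):
    """True for lists/tuples and similar, False for str and bytes."""
    return (isinstance(element, collections.abc.Sequence)
            and not isinstance(element, (str, bytes)))

print(is_non_string_sequence([1, 2]))               # True
print(is_non_string_sequence("ab"))                 # False
print(isinstance({}, collections.abc.Mapping))      # True
print(isinstance((1,), collections.abc.Iterable))   # True
```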
The default initialization for linear and convolutional modules seems to be Glorot initialization. For the commonly used ReLU activation function, He initialization is superior and only requires a quick change to the stddev definition. Should we implement better defaults?
I know that there are many initialization schemes; I only suggest this one because it wouldn't be computationally expensive and would be only a minor code change.
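To make the size of the proposed change concrete, here are the standard-deviation formulas for the two schemes named above, using their usual textbook definitions: Glorot normal uses sqrt(2 / (fan_in + fan_out)) and He normal uses sqrt(2 / fan_in). The function names are illustrative, and nothing here reflects what the library actually computes.

```python
import math

def glorot_stddev(fan_in, fan_out):
    """Glorot (Xavier) normal: variance 2 / (fan_in + fan_out)."""
    return math.sqrt(2.0 / (fan_in + fan_out))

def he_stddev(fan_in):
    """He normal: variance 2 / fan_in, intended for ReLU layers."""
    return math.sqrt(2.0 / fan_in)

fan_in, fan_out = 256, 256
print(glorot_stddev(fan_in, fan_out))  # 0.0625
print(he_stddev(fan_in))               # about 0.0884
```

For a square layer the two differ by a factor of sqrt(2).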
Hi there,
I noticed there have been some changes to FlatMapping.
I can imagine that you don't really see the need for deepcopying a FlatMapping, as it's supposed to be immutable. But just so you know, deepcopy doesn't work anymore:
from copy import deepcopy
from haiku._src.data_structures import FlatMapping
m = FlatMapping.from_mapping({'foo': 'bar'})
deepcopy(m) # raises TypeError: can't pickle jaxlib.pytree.PyTreeDef objects
P.S. I stumbled upon this because I'm deriving a subclass with limited mutability from FlatMapping (whose leaves are all DeviceArrays). I'm using deepcopy for my target-network / behavior-policy weights.
I added a custom implementation to my derived class:
class Foo(FlatMapping):
  ...
  def __deepcopy__(self, memo):
    leaves, treedef = self.flatten()
    return self.__class__((deepcopy(leaves), treedef))
Also.. thanks for the speed-up in FlatMapping!
For concurrent.futures.ProcessPoolExecutor or multiprocessing.Pool, it's necessary to pickle objects in order to send them to other processes. That doesn't work with transformed Haiku functions:
import pickle
import haiku as hk
def forward(x): return hk.Linear(x)
stateful = hk.transform_with_state(forward)
pickle.dumps(stateful)
AttributeError: Can't pickle local object 'transform_with_state.<locals>.init_fn'
stateless = hk.transform(forward)
pickle.dumps(stateless)
AttributeError: Can't pickle local object 'without_state.<locals>.init_fn'
How could we pickle haiku models?
Should we take a different approach?
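The failure is generic to pickle: functions defined inside another function cannot be pickled by reference, and the error messages above indicate that the transformed init functions are such local objects. One common workaround, sketched here under that assumption, is to send only the parameters (plain data) to the worker and rebuild the transformed function there from a module-level function. `make_transformed` is a toy stand-in for `hk.transform`, and NumPy arrays stand in for the parameters.

```python
import pickle
import numpy as np

def make_transformed(fn):
    """Toy stand-in for hk.transform: returns a locally defined function."""
    def apply_fn(params, x):
        return fn(params, x)
    return apply_fn

def forward(params, x):
    """Module-level model function that each worker can import and rebuild."""
    return x @ params["linear"]["w"] + params["linear"]["b"]

apply_fn = make_transformed(forward)

# The locally defined function cannot be pickled.
try:
    pickle.dumps(apply_fn)
    closure_pickled = True
except (AttributeError, pickle.PicklingError):
    closure_pickled = False
print(closure_pickled)  # False

# The parameters are plain data and round-trip without trouble.
params = {"linear": {"w": np.ones((3, 2)), "b": np.zeros(2)}}
restored = pickle.loads(pickle.dumps(params))

# In the worker: rebuild the function, then apply it to the restored params.
out = make_transformed(forward)(restored, np.ones((1, 3)))
print(out)
```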