
dm-haiku's Issues

Stateful functions and target networks?

Hi there,

I'd like to ask for some advice.

Let's say I've got a function approximator for a q-function that uses hk.{get,set}_state() along with hk.transform_with_state(). This means that my function approximator consists of a triplet func, params, state.

I would like to keep a separate copy of this function approximator, i.e. a target network. This means that I keep a separate copy of the params.

Now my question is, would you recommend I also keep a separate copy of the state? And if so, how do we ensure that the usual smooth updates make sense? (e.g. variance typically can't be updated this way unless you map it to the 2nd moment first)
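For reference, a minimal sketch of the kind of update I have in mind (tau, soft_update and copy_state are just illustrative names, not Haiku APIs): Polyak-average the params and hard-copy the state.

import jax
import jax.numpy as jnp

tau = 0.995  # smoothing coefficient for the slow-moving target copy

def soft_update(target_params, online_params):
    # Exponential moving average ("smooth update") of the online params.
    # (Use jax.tree_multimap on older JAX versions.)
    return jax.tree_map(lambda t, o: tau * t + (1.0 - tau) * o,
                        target_params, online_params)

def copy_state(online_state):
    # Hard copy of the state (e.g. moving averages), since averaging a
    # quantity like a variance directly may not be meaningful.
    return jax.tree_map(jnp.asarray, online_state)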

Announcing Elegy!

Hey, I created a framework called Elegy with a couple of friends; it's a Keras-like framework based on Haiku. The experience of creating a higher-level API for training on top of Haiku has been interesting: we've had to wrap some of Haiku's modules and functions to enable additional features, but the overall experience has been very positive! It would be great if some of these features eventually made their way back into Haiku, if you are interested.

https://poets-ai.github.io/elegy/

Feel free to close the issue.

MyModule.__init__ runs every time 'apply' is called

import haiku as hk


class MyModule(hk.Module):
    def __init__(self):
        super().__init__(name="MyModule")
        print("MyModule.__init__ ran")

    def __call__(self, x):
        print("MyModule.__call__ ran")
        return x + 1


def forward(x):
    return MyModule()(x)


rng = hk.PRNGSequence(420)
fwd = hk.transform(forward)
params = fwd.init(next(rng), 1)
x = 0
for _ in range(10):
    x = fwd.apply(params, x)
Output:

MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran
MyModule.__init__ ran
MyModule.__call__ ran

Is this intended? It eventually caused my module to break, because it pulled different hyperparameters at different invocations.

Feature Request: some way to pass hyperparameters out of transforms

Do you want to hand-tune your models? Most folks don't, because it's slow. So we get into hyperparameter optimization.

In Haiku, we define a model like

def forward(x):
    return MyModule(hyperparameters)(x)

model = hk.transform(forward)

where hyperparameters is a dict or class that records decisions modules make when they "pull" choices such as "number of layers", "latent code dimensionality", or "block 2 n_heads". We cannot enumerate all possible hyperparameters ahead of time, because some depend on the values of earlier ones and there are simply too many combinations.

Those hyperparameters become "stuck" inside the forward function, unless we make hyperparameters a mutable data structure and mutate it inside all our MyModule.__init__ methods. Then the forward function has deeply nested side effects when it is defined, so you cannot share the decisions across a group of agents (they would mutate shared state), and the Haiku tracer needs to know to ignore these __init__ hyperparameter mutations, which I assume it does, but I'm not 100% sure. Mutating hyperparameters at definition time is also confusing for engineers, because there are many "init" steps happening at different times.

What would be an elegant way for forward functions to pull hyperparameters and return their decisions?
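For illustration, a minimal workaround sketch (not a proposed API): have the forward function return the decisions it made as a second output, so they come back out of apply instead of staying trapped in the closure. The hparams dict and sizes below are made up.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    hparams = {"hidden": 300}            # decision "pulled" at trace time
    out = hk.Linear(hparams["hidden"])(x)
    return out, hparams                  # return decisions alongside the output

model = hk.without_apply_rng(hk.transform(forward))
x = jnp.zeros((1, 4))
params = model.init(jax.random.PRNGKey(0), x)
out, hparams = model.apply(params, x)    # hparams now visible outside the transform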

JAX version in Colab

Right now Colab ships jax 0.1.69, so importing haiku fails.
I believe it would be better to add both jax and jaxlib as normal requirements, so that installing dm-haiku pulls in the correct versions of the dependencies.

Fatal Python error: Aborted when running mnist.py example

Hi there! I've been trying to get familiar with the library by running some examples in the examples/ folder. My environment was set up according to the instructions on https://github.com/google/jax#installation and https://github.com/deepmind/dm-haiku#installation.

On running the mnist.py example with TensorFlow 2.1.0, a Fatal Python error: Aborted occurs. The full error message is below:

2020-05-11 20:08:28.772679: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772772: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64
2020-05-11 20:08:28.772790: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
I0511 20:08:29.945827 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:29.947907 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:29.948163 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:29.948284 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.869276 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.870402 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.870607 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.870710 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split train, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.915239 140124110780224 dataset_info.py:361] Load dataset info from /home/tedmund/tensorflow_datasets/mnist/3.0.1
I0511 20:08:30.916428 140124110780224 dataset_info.py:401] Field info.citation from disk and from code do not match. Keeping the one from code.
I0511 20:08:30.916637 140124110780224 dataset_builder.py:283] Reusing dataset mnist (/home/tedmund/tensorflow_datasets/mnist/3.0.1)
I0511 20:08:30.916740 140124110780224 dataset_builder.py:479] Constructing tf.data.Dataset for split test, from /home/tedmund/tensorflow_datasets/mnist/3.0.1
2020-05-11 20:08:32.182307: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:236] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2020-05-11 20:08:32.182349: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms) 
Fatal Python error: Aborted

Current thread 0x00007f712fd8f740 (most recent call first):
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 156 in compile
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jaxlib/xla_client.py", line 576 in Compile
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 197 in xla_primitive_callable
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/interpreters/xla.py", line 166 in apply_primitive
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/core.py", line 199 in bind
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 626 in dot_general
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/lax/lax.py", line 564 in dot
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 2484 in dot
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 161 in __call__
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/basic.py", line 120 in __call__
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/module.py", line 155 in wrapped
  File "mnist.py", line 41 in net_fn
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 271 in init_fn
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/haiku/_src/transform.py", line 106 in init_fn
  File "mnist.py", line 112 in main
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 250 in _run_main
  File "/home/tedmund/anaconda3/envs/tf2/lib/python3.6/site-packages/absl/app.py", line 299 in run
  File "mnist.py", line 131 in <module>
Aborted (core dumped)

One workaround I've found is the usual TensorFlow one, inserting the following code:

from tensorflow.compat.v1 import ConfigProto, InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.7
sess = InteractiveSession(config=config)

However, this somewhat defeats the purpose if one is simply trying to use JAX/NumPy instead of TensorFlow. I'm not sure what else I can provide to help; please let me know!
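Another workaround that might be worth trying, assuming TensorFlow is only used here for the tfds input pipeline: hide the GPU from TensorFlow entirely so it does not compete with JAX/XLA for device memory.

import tensorflow as tf

# Keep TF on the CPU; tfds only needs it for data loading.
tf.config.experimental.set_visible_devices([], "GPU")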

Is there a good way to save/load & compress/decompress model weights?

Hey, this is Chris.
I'm using this open-source library for my project.

https://github.com/chris-chris/haiku-scalable-example

Since I'm new to JAX and haiku, I have some questions.

Is there a good way to save/load & compress/decompress & serialize model weights?

  • save/load model (network only or weight only)
  • compress/decompress weights
  • serialize

I think serialization is an important issue for scalability. Can you give me some keywords or hints about it?

Thanks!
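A minimal sketch of one possible approach (assuming plain pickle is acceptable; save_params/load_params are illustrative helper names, not Haiku APIs): since params are just a nested mapping of arrays, pull them to host memory and use standard Python tools.

import pickle

import jax

def save_params(path, params):
    # Haiku params are a nested mapping of arrays; move them to host memory first.
    with open(path, "wb") as f:
        pickle.dump(jax.device_get(params), f)

def load_params(path):
    with open(path, "rb") as f:
        return pickle.load(f)

Compression could then be layered on top with gzip or numpy's savez_compressed.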

Disable __setattr__ in frozendict

There can be some strange behavior in frozendict due to attribute assignment:

import haiku as hk

frozen = hk.data_structures.to_immutable_dict({'a': 'foo'})
print(frozen['a'], frozen.a)  # --> foo foo
frozen.a = 'bar'              # should raise error
print(frozen['a'], frozen.a)  # --> foo bar

My suggestion would be to simply disallow any setattr on frozendict.

Also, as a side note, please expose frozendict in the public API. Right now, the only way to access it is through haiku._src.data_structures.frozendict.

Raise error when user passes bad window/stride for pooling

Currently, MaxPool and AvgPool expect window shape and stride to be "Same rank as value or int.". I think the layers should validate the window shape and stride and raise an error (or perhaps a warning) if the user does not conform to this.

Coming from other frameworks, users are used to passing window shapes/strides with the same rank as the number of spatial dimensions. For example, for a 2D max pool, users would pass window_shape=(2, 2), which resolves to (1, 1, 2, 2) in Haiku, which is likely not what they intended. Raising an error would force the user to pass window_shape=(1, 2, 2, 1) or window_shape=2, and would save them debugging time.

Dealing with conditionally constant state

How could I add a constant state to my haiku module?
Specifically I would want something like this:

class MyModule(hk.Module):
  def __init__(self, output_size, const, name=None):
    super().__init__(name=name)
    if const:
      # Constant: should NOT be updated when applying gradients.
      self.b = jnp.ones(output_size)
    else:
      # Parameter: should be updated when applying gradients.
      self.b = hk.get_parameter("b", [output_size], init=jnp.zeros)

Allow `hk.next_rng_key` calls outside of `hk.transform` with `hk.with_rng`

Currently, there is no easy way to just call an initializer outside of hk.transform. This can be worked around, but it discourages using hk.next_rng_key when writing custom initializers, since any code depending on hk.next_rng_key is unusable without buying into everything else. This is a minor annoyance when debugging an initializer, since there is no way to retrieve sampled weights without wrappers. I could definitely imagine this becoming more of a pain in more complex settings that involve writing many initializers (admittedly, whether that is a plausible scenario is a different question).

This is definitely not a high-priority thing, but if there were a way to allow using hk.with_rng outside of transform, it wouldn't be an issue anymore. If the different parts of the frame-stack backend are too interdependent to safely allow hk.with_rng outside of hk.transform, then it's probably not worth the effort, but it never hurts to ask!
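The workaround I currently use, for reference: a throwaway transform just to sample from an initializer (sample_init is my own helper name, not a Haiku API).

import haiku as hk
import jax
import jax.numpy as jnp

def sample_init(init, shape, dtype=jnp.float32, seed=0):
    # Wrap the initializer call in a transform so hk.next_rng_key is available.
    f = hk.transform(lambda: init(shape, dtype))
    return f.apply({}, jax.random.PRNGKey(seed))

w = sample_init(hk.initializers.TruncatedNormal(stddev=0.1), (3, 3))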

I'm curious about the reason why you use JAX.

After you open-sourced haiku and rlax recently, I started learning JAX, and I love these projects.
However, whenever I share these projects with deep learning communities and my friends, they don't seem convinced to use them. I believe you must have solid reasons for choosing JAX for your RL research framework, which I've tried to guess. So far, these are the advantages I could think of:

  • Pure function without side effects
  • Performance enhancement using XLA
  • Lightweight and NEAT
  • Fully utilizing vectorization and parallelization
  • Flexible architecture for a complicated distributed learning system

But I'm still not fully aware of the logic behind it. Would it be possible for you to briefly explain the motivation behind using JAX for your research?

Faithfully reconstruct tree from context

This variation on @tomhennigan's example tries to build a tree of module types.

It assumes the parameter creation order is preserved when flattening the parameter dictionary, which may be incorrect. Alternatively, if the path could be added to the context, or if it were possible to recover the path from the context, that would support a more satisfying solution. With module names and parameter names possibly containing "/", it is not clear to me how to construct the path. What am I missing?

import haiku
import tree  # dm-tree


def init_and_build_module_tree(f):
    """
    Decorated functions build a tree of module types alongside the parameters

    Usage:
      def f(x):
        net = haiku.nets.MLP([300, 100, 10])
        return net(x)

      params, modules = init_and_build_module_tree(f)(rng_key, np.zeros(4))
      params = tree.map_structure(transform_params, params, modules)
    """

    def _init_and_build_module_tree(rng_key, *args, **kwargs):
        module_types = []

        def record_module_type(next_creator, shape, dtype, init, context):
            module_types.append(type(context.module))
            return next_creator(shape, dtype, init)

        def with_creator(*aargs, **kkwargs):
            with haiku.experimental.custom_creator(record_module_type):
                return f(*aargs, **kkwargs)

        params, _ = haiku.transform_with_state(with_creator).init(
            rng_key,
            *args,
            **kwargs
        )

        module_tree = tree.unflatten_as(
            params,
            module_types
        )

        return params, module_tree

    return _init_and_build_module_tree

Calling __init__ outside transform

Hey,

I want to create a Haiku-compatible library called Elegy that provides a Keras-like interface. I love how easy Haiku is to use; however, I would like users to be able to write something like this:

module = hk.Sequential([
      hk.Flatten(),
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
])

model = Model(
      module,
      loss=...,
      metrics=dict(
            accuracy=Accuracy(),  # this is probably an hk.Module as well
      ),
      optimizer=...,
)

Instead of checking whether __init__ is called inside hk.transform, wouldn't it make more sense to check whether __call__ is called inside that context? If that is not possible because you cannot intercept __call__, maybe having the user implement something like call or apply would be better?

KeyError in conditional state (solved)

In a sparsely gated mixture of experts where each expert has state (a memory), there is a KeyError if different experts are activated during apply than happened to be activated during init. To solve this, you can pass an 'init' flag into your custom module; if it is True, just use all the experts on that call. If that is a memory issue, you can use them one by one. Just make sure init hits all conditional branches of the modules with state.

(let me know if there's an easier solution)

hk.Module should be an abstract Callable

Hey, I use pyright / pylance for type checking, and they are pretty unhappy that hk.Module doesn't define an abstract __call__ method; I get type errors all over the place in code that takes arbitrary hk.Modules. Given that most of Haiku is already typed, this would be a nice addition.
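In the meantime, a workaround sketch that keeps pyright happy in my code: a structural type via typing.Protocol (CallableModule is just an illustrative name; Protocol is in typing from Python 3.8, typing_extensions earlier).

from typing import Protocol

import jax.numpy as jnp

class CallableModule(Protocol):
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        ...

def run(module: CallableModule, x: jnp.ndarray) -> jnp.ndarray:
    # Any hk.Module (or plain function) with a matching __call__ satisfies this.
    return module(x)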

In haiku, is there something equivalent to flax.nn.Model?

Thanks for creating such a nice library!

My example use case in flax:

nn = ...  # flax.nn.Module
_, nn_params = nn.init(rng_key, data)
model = flax.nn.Model(nn, nn_params)

With this, i can call model(x) and i can also access current params via model.params.

For haiku

nn = ...  # transformed forward function (result of hk.transform)
nn_params = nn.init(rng_key, data)
# I can do this:
partial_fun = lambda data: nn.apply(nn_params, rng_key, data)

I can call partial_fun(x), but because partial_fun is a lambda, I can't access nn_params from it. Is there any workaround to achieve this in Haiku?

For more background, I am trying to integrate haiku with numpyro so that we can convert a traditional NN into a Bayesian NN. See issue pyro-ppl/numpyro#705 or this example notebook for more context.
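For what it's worth, the closest thing I could come up with is a small wrapper of my own (not a Haiku API; Model here is just a sketch) that bundles the transformed function with its params, so both stay accessible:

import haiku as hk

class Model:
    def __init__(self, transformed: hk.Transformed, params):
        self.transformed = transformed
        self.params = params

    def __call__(self, *args, rng=None):
        return self.transformed.apply(self.params, rng, *args)

# Usage sketch:
#   net = hk.transform(forward_fn)
#   params = net.init(rng_key, data)
#   model = Model(net, params)
#   model(data); model.params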

hk.stateful.remat generates excess un-pruneable HLO

jax.remat wraps all of its inputs with _foil_cse.

When we do the state-threading in hk.stateful.remat, the threaded-out state is now the output of _foil_cse. Any downstream uses of this state now access the foil-cse'd param/state rather than the original.

Example:

import jax
import jax.numpy as jnp

def f(x, ctxt):
  return jnp.sin(x + ctxt[0]), ctxt

@jax.jit
def g(x):
  ctxt = [x + i for i in range(2)]
  x, ctxt = jax.remat(f)(x, ctxt)
  return jnp.sin(x + ctxt[1])

g(1.).block_until_ready()

This results in HLO that looks like:

HloModule jit_g__1.46, is_scheduled=true

ENTRY jit_g__1.46 {
  constant.2 = f32[]{:T(256)} constant(2)
  constant = f32[]{:T(256)} constant(0)
  constant.5 = f32[]{:T(256)} constant(1)
  rng.2 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare.2 = pred[]{:T(256)E(32)} compare(rng.2, constant.2), direction=LT
  rng.1 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare.1 = pred[]{:T(256)E(32)} compare(rng.1, constant.2), direction=LT
  rng = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare = pred[]{:T(256)E(32)} compare(rng, constant.2), direction=LT
  parameter.1 = f32[]{:T(256)} parameter(0), parameter_replication={false}
  select = f32[]{:T(256)} select(compare, parameter.1, constant)
  select.1 = f32[]{:T(256)} select(compare.1, parameter.1, constant)
  add = f32[]{:T(256)} add(select, select.1)
  sine = f32[]{:T(256)} sine(add)
  add.6 = f32[]{:T(256)} add(parameter.1, constant.5)
  select.2 = f32[]{:T(256)} select(compare.2, add.6, constant)
  add.43 = f32[]{:T(256)} add(sine, select.2)
  sine.44 = f32[]{:T(256)} sine(add.43)
  ROOT tuple.45 = (f32[]{:T(256)}) tuple(sine.44)
}

Possible solutions:

  1. Reduce the amount of state-motion in & out of stateful_fun, especially during apply.
    • Params are immutable during apply, don't thread them in/out.
    • Pre-split the RNG and populate a new hk.PRNGSequence inside stateful_fun so that RNG doesn't get threaded in/out.
    • state is only updated in-place for state that's actually been changed. JAX referential transparency makes this challenging for the case in which Haiku is not jitted but internal functions are via hk.jit.
  2. Rebuild hk.remat on top of hk._src.lift.

Support for custom pytrees

Hello, haiku team! Thanks a lot for making awesome haiku.

I'm interested in sequential probabilistic models. Normally, the parameters of probabilistic models are constrained; a simple example is a variance, which can only be positive. I gave an example and explanation of constrained parameters in #16 (comment). Pytrees fit this use case ideally: the user can create their own differentiable "vectors", and I would expect Haiku to support these custom structures out of the box. This would allow a user to get back the actual structures from transformed functions for printing, debugging, and plotting purposes (the list could be extended with other examples from academic needs). Unfortunately, custom differentiable structures don't work at the moment.

Failing example

import jax
import jax.numpy as jnp
import haiku as hk
from jax.tree_util import register_pytree_node


class S(hk.Module):
  def __init__(self, x, y):
    super().__init__()
    # These are parameters:
    self.x = x
    self.y = y

  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)


def S_flatten(v):
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)


def S_unflatten(aux_data, children):
  return S(*children)


register_pytree_node(S, S_flatten, S_unflatten)


def function(s):
  return jnp.sqrt(s.x**2 * s.y**2)


def loss(x):
  s = S(1.0, 2.0)
  a = hk.get_parameter("free_parameter", shape=[], dtype=jnp.float32, init=jnp.zeros)
  return jnp.sum(function(s) * a * x)


x = jnp.array([2.0])
forward = hk.transform(loss)
key = jax.random.PRNGKey(42)
params = forward.init(key, x)

print(params)
# frozendict({
#   '~': frozendict({'free_parameter': DeviceArray(0., dtype=float32)}),
# })

Thanks

Total Number of Parameters in Haiku Module

I was wondering if there was a simple way to find the total number of parameters in a haiku model, apart from iterating through the different layers and counting them.
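For reference, one way to do this without touching individual layers (a sketch using plain JAX tree utilities; num_params is just an illustrative helper):

import jax

def num_params(params):
    # Sum the sizes of all leaf arrays in the (nested) params mapping.
    return sum(p.size for p in jax.tree_util.tree_leaves(params))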

`apply` method really slow on CPU

I am implementing the VGG16 architecture, which is pretty straightforward, but when I call init on the transformed function to obtain the initial parameters, the call never finishes.

My code:

import jax
import jax.numpy as np

import haiku as hk

from PIL import Image

cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

def _make_layers():
    layers = []
    in_channels = 3

    for v in cfg:
        if v == 'M':
            layers.append(hk.MaxPool(window_shape=[2, 2], 
                                     strides=[2, 2],
                                     padding="VALID"))
        else:
            conv2d = hk.Conv2D(v, 
                               kernel_shape=[3, 3], 
                               stride=1, 
                               padding='SAME')
            layers += [conv2d, jax.nn.relu]
            in_channels = v

    return layers

class VGG16(hk.Module):
    
    def __init__(self, num_classes: int = 1000) -> None:
        super(VGG16, self).__init__()
        features = _make_layers()
        classifier = [
                hk.Flatten(),
                hk.Linear(4096), jax.nn.relu,
                hk.Linear(4096), jax.nn.relu,
                hk.Linear(num_classes), jax.nn.softmax
        ]
        self.vgg = hk.Sequential(features + classifier)

    def __call__(self, x):
        return self.vgg(x)

def _forward(image):
    net = VGG16()
    return net(image)

rng = jax.random.PRNGKey(0)

print('Loading image...')
im = Image.open('images/dog.jpeg').resize((224, 224))
im = np.array(im).astype('float32') / 255.

print('hk.transform...')
vgg_forward_fn = hk.transform(_forward)

print('This will never end..')
params = vgg_forward_fn.init(rng, np.expand_dims(im, 0))

I am on Windows 10 using WSL. I have just implemented the same architecture using the experimental stax library, and there the initialization finishes in about one second.

Any idea what I am doing wrong?

Thanks in advance

hk.add_loss

To enable users to easily create per-layer weight and activity regularizers, plus other losses produced by intermediate layers, it would be very useful if Haiku had an hk.add_loss utility. When called within a transform, it would append a loss to a list of losses that the user could later retrieve as an additional output from apply. I guess this would require an additional flag to hk.transform and friends.
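To illustrate the idea, a rough workaround sketch of what can be done today with Haiku state, inside hk.transform_with_state (add_loss is an illustrative name, not a proposed implementation):

import haiku as hk
import jax.numpy as jnp

def add_loss(value):
    # Accumulate auxiliary losses in Haiku state.
    total = hk.get_state("aux_loss", shape=(), dtype=jnp.float32, init=jnp.zeros)
    hk.set_state("aux_loss", total + value)

The accumulated total then comes back as part of the state returned by init/apply, which is close to what a built-in hk.add_loss could expose directly.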

Write Haiku module with custom gradient

Hi Haiku Team! Thank you for all your work on Haiku.

I'm interested in writing a layer which takes a function as an argument and produces custom gradients. (For context, I'm implementing a method to find the stationary point of a function.)

A toy example of my implementation in pure JAX is below, and the full implementation is here:

from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def g(x: jnp.ndarray, fun: Callable):
    return jax.lax.stop_gradient(fun(x))

def g_fwd(x, fun):
    return g(x, fun), (x,)

def g_bwd(fun, res, grad):
    x, = res
    return (fun(x),)

g.defvjp(g_fwd, g_bwd)

My question: What is the best way to implement this in Haiku?

I see some approaches:

  • Create a JAX function which takes a haiku.Module as an argument (in place of the function). This currently leads to issues with the Module in the backward pass (it doesn't seem to be transformed).

  • Create a haiku.Module with custom gradients. However, this is a pure function, so creating a Module feels wrong as it doesn't require parameters (but may require the Haiku state of another module?).

  • Use haiku.to_module(f). In this approach I'd use get_parameters to access the states of the input function, and potentially use some naming conventions to make sure I have the correct scope. I imagine this is the best approach (and maybe very similar to the first approach), but I really can't find much documentation on how naming or accessing variables actually works!

Would it be possible to share any prior art on how to implement any of these approaches?

Approach 1
import haiku as hk
import jax
import jax.numpy as jnp
from functools import partial
from typing import Callable

def build_net(output_size):
    def forward_fn(x: jnp.ndarray) -> jnp.ndarray:
        linear = hk.Linear(output_size, name='l1')
        x = linear(x)
        return g(x, linear)
    return forward_fn

b_size, s_size, h_size = 1, 2, 3 
input = jnp.ones((b_size, s_size, h_size))
rng = jax.random.PRNGKey(42)
net = build_net(h_size)
net = hk.transform(net)
params = net.init(rng, input)

def loss_fn(params, rng, x):
    return jnp.sum(net.apply(params, rng, x))

print(jax.grad(loss_fn)(params, rng, input))

Gives:

     15 def g_bwd(fun, res, grad):
     16     input, = res
---> 17     return fun(input)
     18 
     19 g.defvjp(g_fwd, g_bwd)

/usr/local/lib/python3.6/dist-packages/haiku/_src/module.py in wrapped(self, *args, **kwargs)
    299     if not base.frame_stack:
    300       raise ValueError(
--> 301           "All `hk.Module`s must be initialized inside an `hk.transform`.")
    302 

ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

Missing files in sdist

It appears that the manifest is missing at least one file necessary to build from the sdist for version 0.0.1b0. You're in good company: about 5% of other projects updated in the last year are also missing files.

+ /tmp/venv/bin/pip3 wheel --no-binary dm-haiku -w /tmp/ext dm-haiku==0.0.1b0
Looking in indexes: http://10.10.0.139:9191/root/pypi/+simple/
Collecting dm-haiku==0.0.1b0
  Downloading http://10.10.0.139:9191/root/pypi/%2Bf/f19/fdaf8281b7fb0/dm-haiku-0.0.1b0.tar.gz (121 kB)
    ERROR: Command errored out with exit status 1:
     command: /tmp/venv/bin/python3 -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py'"'"'; __file__='"'"'/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' egg_info --egg-base /tmp/pip-wheel-8pyg10w7/dm-haiku/pip-egg-info
         cwd: /tmp/pip-wheel-8pyg10w7/dm-haiku/
    Complete output (7 lines):
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py", line 56, in <module>
        install_requires=_parse_requirements('requirements.txt'),
      File "/tmp/pip-wheel-8pyg10w7/dm-haiku/setup.py", line 33, in _parse_requirements
        with open(requirements_txt_path) as fp:
    FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
    ----------------------------------------
ERROR: Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.

Omnistaging (still) breaks LSTM Model

I left an issue on the JAX repo about omnistaging breaking LSTMs; it was closed since @tomhennigan mentioned it was fixed in a recent PR, but it seems it is not. In summary, JIT-compiling the gradient of a loss through an LSTM with static unrolling never finishes compiling; it keeps using more and more RAM until it crashes.

Here is a gist reproducing the issue (with Haiku and JAX at head). The colab can't be run sequentially: you must run the top portion then one of "With OmniStaging" or "Without OmniStaging" (and restart to run the other one).

The code works when I use dynamic_unroll instead of static_unroll, but this is just a temporary workaround as I would like to use static_unroll.

Allow creating module instances outside hk.transform

This is as much a question as it is a feature request. What is the reasoning for not allowing a module instance to be created (but not used) outside hk.transform? I took a look at hk.Module and its metaclass, but I feared my soul would be harvested by the dark forbidden magic involved before I could identify all the API features it permits.

For example, I would have expected this to be possible:

linear = hk.Linear(10)  # currently not allowed

def forward(x):
  return linear(x)

model = hk.transform(forward)

Concretely, I'm curious to know what would have to be sacrificed (if anything) to support this kind of usage. Is it meant to prevent a module instance from being used in two different functions wrapped by two different hk.transform calls?

I wouldn't be surprised if I were missing some nasty side effect if you were to allow module creation outside of hk.transform, but, if not, I think it would be more intuitive to allow this kind of usage.

Folder Structure

I've noticed _src has become quite large. I think eventually splitting it up into folders makes more sense. We could have:

  • nn
  • initializers
  • regularizers
  • losses
  • metrics

hk.ResetCore Requires A Batch Dimension in Inputs

Hi Haiku team,

Thanks for opensourcing such a great library!

It looks like hk.ResetCore requires the leading dimension of the inputs to be the batch dimension: https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py#L638. However, other RNN cores, like LSTM (https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/recurrent.py#L264), do not have such a requirement. In fact, I found this mismatch when I was using an RNN without a batch dimension in the inputs.

Could you perhaps change hk.ResetCore so that it also works with inputs without a batch dimension?

Thanks!
Zeyu

Jax version upgrade (AttributeError: CallPrimitive)

Using the current master of Haiku (66f9c69), I am getting the following error on Colab:

AttributeError                            Traceback (most recent call last)
<ipython-input-3-3a9e6adbfff5> in <module>()
----> 1 import haiku as hk

/usr/local/lib/python3.6/dist-packages/haiku/__init__.py in <module>()
     17 
     18 from haiku import data_structures
---> 19 from haiku import experimental
     20 from haiku import initializers
     21 from haiku import nets

/usr/local/lib/python3.6/dist-packages/haiku/experimental.py in <module>()
     22 from haiku._src.base import custom_getter
     23 from haiku._src.base import ParamContext
---> 24 from haiku._src.dot import to_dot
     25 from haiku._src.lift import lift
     26 from haiku._src.module import profiler_name_scopes

/usr/local/lib/python3.6/dist-packages/haiku/_src/dot.py in <module>()
     23 
     24 from haiku._src import data_structures
---> 25 from haiku._src import module
     26 from haiku._src import utils
     27 import jax

/usr/local/lib/python3.6/dist-packages/haiku/_src/module.py in <module>()
     26 from haiku._src import base
     27 from haiku._src import data_structures
---> 28 from haiku._src import named_call
     29 from haiku._src import utils
     30 import jax.numpy as jnp

/usr/local/lib/python3.6/dist-packages/haiku/_src/named_call.py in <module>()
     29 
     30 # Registering named call as a primitive
---> 31 named_call_p = core.CallPrimitive('named_call')
     32 # named_call is implemented as a plain core.call and only diverges
     33 # under compilation (see named_call_translation_rule)

AttributeError: module 'jax.core' has no attribute 'CallPrimitive'

I believe that's because Haiku now requires jax>=0.1.71, while the version by default on Colab is jax==0.1.69. CallPrimitive was introduced in jax 0.1.71.
https://github.com/google/jax/blob/1545a29e6d69a7b3c7fdf9a49b38004759a9fbfa/jax/core.py#L1106-L1115

To reproduce (inside a Colab):

import jax
print(jax.__version__)  # 0.1.69

!pip install -q git+https://github.com/deepmind/dm-haiku
import haiku as hk

Run !pip install -q --upgrade jax jaxlib first in your Colab to fix this issue.

Ensure float32 inputs imply float32 outputs when jax_enable_x64=1

import os
os.environ["JAX_ENABLE_X64"] = "1"

import jax
import haiku as hk
import numpy as np

@hk.transform
def f(x):
  return hk.Linear(4)(x)

f32_data = np.zeros((4, 8), dtype=np.float32)

p = f.init(jax.random.PRNGKey(428), f32_data)
print(jax.tree_map(lambda t: t.dtype, p))
f32_params = jax.tree_map(lambda t: t.astype(np.float32), p)
print(f.apply(f32_params, f32_data).dtype)

Prints:

frozendict({
  'linear': frozendict({'b': dtype('float32'), 'w': dtype('float64')}),
})
dtype('float32')

Hopefully the bfloat16 compatibility work means we get most of this for free and that we only need to port the initializers.

Iterating through hk modules

Let's say I want to iterate through all modules inside an hk model and replace all hk.Linears with my own custom Module, or monkey-patch some of their properties. Does haiku currently support something along these lines?

New interface for spectral normalization

I noticed earlier today that Haiku has SpectralNormalization -- very cool!

I'm interested in implementing an improved version, which does a much better job estimating the norm for convolutional layers and should converge to the correct answer for any linear operator. The trick is to use auto-diff to calculate the transpose of the linear operator. In contrast, the current implementation is only accurate for dense matrices.

Here's my implementation in pure JAX: https://nbviewer.jupyter.org/gist/shoyer/fa9a29fd0880e2e033d7696585978bfc

My question: how can I implement this in Haiku?

  • It feels like the right way to write this would be as a Module that takes another Module (or function) as an argument, but I don't know of any existing prior art for that. Does that make sense to you?
  • How do I call jax.vjp on a Module? I'm guessing (though to be honest I haven't checked yet) that the normal JAX function would break, given the way that Haiku adds mutable state.

extra dimension for jacobian of model, or jacfwd on model.apply

Hi:

I'm computing the Jacobian of my model on every step. The input is a z-dimensional vector and the model's output is (bz, ch, h, w); it's a decoder. I expected the output of the Jacobian to be (bz, ch, h, w, z), and it is when I compute the Jacobian of a single row of my batch. However, when I pass in the whole batch, the dimensions become (bz, ch, h, w, bz, z). Why the extra bz?

I checked the Jacobian output values for the i-th row: (i, ch, h, w, j, z) is zero for all j != i, and it has the correct values when j = i. I can still use this; however, I have to take the extra step of removing the zero outputs.

Here's my code for the Jacobian:

decoder_jac = hk.transform(lambda z: jacfwd(Decoder())(z))
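If it helps clarify what I mean, here is a self-contained sketch of the behaviour and the per-example workaround in plain JAX (decoder_fwd is a stand-in for Decoder()(z), not my real model):

import jax
import jax.numpy as jnp

def decoder_fwd(z):
    # Stand-in decoder: maps a (z,) vector to a (ch, h, w) output.
    return jnp.tanh(z[:3])[:, None, None] * jnp.ones((3, 4, 4))

z_batch = jnp.ones((8, 16))                              # (bz, z)

# Jacobian of the batched function: (bz, ch, h, w, bz, z), with zero
# cross-example blocks, because every output is differentiated w.r.t.
# every row of the batched input.
full_jac = jax.jacfwd(lambda zb: jax.vmap(decoder_fwd)(zb))(z_batch)

# Per-example Jacobian vmapped over the batch: (bz, ch, h, w, z).
per_example_jac = jax.vmap(jax.jacfwd(decoder_fwd))(z_batch)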

Name change or adjustment?

Hi! I'm one of the developers of Haiku and a board member of Haiku, Inc.

We have a trademark on the name Haiku for our open-source operating system, as well as a registered logo mark on our Haiku logo.

Haiku has been in development for nearly 20 years now, and usage of the Haiku name for other open source software projects could create confusion.

Please feel free to reach out to Haiku, Inc. if you have any questions.

https://haiku-os.org
https://haiku-inc.org

Haiku LSTM vs Torch LSTM

I'm trying to figure out the differences between PyTorch's LSTM and Haiku's LSTM. In Haiku, the LSTM expects the input to be a rank-1 or rank-2 tensor; in PyTorch, the input is expected to be a rank-3 tensor. I was hoping to get some clarification on how one would get the same behavior from Haiku's LSTMs as from Torch's LSTMs.

Some additional information: if this is the case, it seems like the LSTM in Haiku is significantly slower than Torch's. It's very likely that I've made a mistake somewhere here.

Also, would it be correct to say that Haiku's static_unroll + LSTM is the same as PyTorch's LSTM? In which case, isn't Haiku's LSTM technically an LSTM cell?
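For reference, a minimal sketch of what I believe the Torch-like usage looks like in Haiku (time-major [T, B, F] inputs; sizes are made up): hk.LSTM is a single-step core, and static_unroll/dynamic_unroll run it over the time axis.

import haiku as hk
import jax
import jax.numpy as jnp

def unrolled_lstm(x):                        # x: [T, B, F], time-major
    core = hk.LSTM(hidden_size=32)
    state = core.initial_state(batch_size=x.shape[1])
    outputs, final_state = hk.static_unroll(core, x, state)
    return outputs                           # [T, B, 32]

f = hk.transform(unrolled_lstm)
x = jnp.zeros((5, 4, 8))
params = f.init(jax.random.PRNGKey(0), x)
out = f.apply(params, None, x)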

Merge FlatMappings with DeviceArray as values raises `AttributeError: 'DeviceArray' object has no attribute 'items'`

Hi, I'm trying to do meta-learning with some slow and fast weights. From the params returned by .init, I obtain the fast weights like this:

params = f.init(...)
fast_weights = params["fast_weights"]

And just before calling the .apply function, I want to merge them back in (the fast weights will have been modified). My first attempt was to use the hk.data_structures.merge method like this:

params = hk.data_structures.merge(params, fast_weights)
output = f.apply(params, rng, inputs)

But this raises AttributeError: 'DeviceArray' object has no attribute 'items'. I was wondering if this behaviour is intended, or if I should use a different approach for what I want to do.

Thanks!
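In case it's useful, a self-contained repro plus a fix, based on my guess at the cause (not confirmed by the Haiku team): merge expects two-level {module_name: {param_name: array}} mappings, and params["fast_weights"] is only one level deep, so merge ends up calling .items() on the arrays.

import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
    return hk.Linear(2, name="fast_weights")(x)

transformed = hk.transform(f)
params = transformed.init(jax.random.PRNGKey(0), jnp.zeros((1, 3)))

fast_weights = params["fast_weights"]      # one-level mapping: {param_name: array}

# hk.data_structures.merge(params, fast_weights)   # raises the AttributeError above
merged = hk.data_structures.merge(params, {"fast_weights": fast_weights})  # works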

Improve `params_dict()` support

module.params_dict() can behave in surprising ways:

import haiku as hk
import jax
import numpy as np

def f(x):
  mod = hk.Linear(8)
  print(mod.params_dict())  # empty during init, full during apply
  sequential = hk.Sequential([mod])
  print(sequential.params_dict())  # always empty
  out = sequential(x)
  print(sequential.params_dict())  # no longer empty
  return out

net = hk.transform(f)
p = net.init(jax.random.PRNGKey(428), np.zeros((2, 3)))
net.apply(p, np.zeros((2, 3)))

Prints:

{}
{}
{...}
{...}
{}
{...}

We should clean up & clearly define the desired semantics of params_dict().

Remove optix, add optax

The JAX team is going to remove optix from the library now that optax exists. Can I open a PR changing all imports from jax.experimental.optix to optax?

hk.Embed's embedding_matrix argument can't be supplied a np.ndarray

Seems like unintended behavior.

/usr/local/lib/python3.6/dist-packages/haiku/_src/embed.py in __init__(self, vocab_size, embed_dim, embedding_matrix, w_init, lookup_style, name)
     73     """
     74     super(Embed, self).__init__(name=name)
---> 75     if not embedding_matrix and not (vocab_size and embed_dim):
     76       raise ValueError(
     77           "hk.Embed must be supplied either with an initial `embedding_matrix` "

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Importing ABC directly from collections will be removed in Python 3.10

Use collections.abc instead. Also, collections.abc.KeysView is used in data_structures.py (see below), but the import statement at the top is just import collections. Accessing collections.abc with only import collections is not supported.

haiku/_src/utils.py
121:      not isinstance(element, collections.Sequence)):

haiku/_src/data_structures.py
80:class KeysOnlyKeysView(collections.abc.KeysView):

haiku/_src/layer_norm.py
79:    elif (isinstance(axis, collections.Iterable) and

haiku/_src/stateful.py
48:    if isinstance(v, collections.Mapping):

He initialization

The default initialization for linear and convolutional modules seems to be Glorot initialization, but for the commonly used ReLU activation function, He initialization is superior, while only requiring a quick change to the stddev definition. Should we implement better defaults?
I know that there are many initialization schemes; I only suggest this one because it wouldn't be computationally expensive and would be only a minor code change.
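For what it's worth, He initialization is already expressible with the existing initializers, so the suggestion is really about the defaults; a sketch:

import haiku as hk

def forward(x):
    # He / Kaiming initialization via variance scaling: scale=2.0 on fan_in.
    he_init = hk.initializers.VarianceScaling(2.0, "fan_in", "truncated_normal")
    return hk.Linear(128, w_init=he_init)(x)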

deepcopy broken for new FlatMapping

Hi there,

I noticed there have been some changes to FlatMapping.

I can imagine that you don't really see the need for deepcopying a FlatMapping, as it's supposed to be immutable. But just so you know, deepcopy doesn't work anymore:

from copy import deepcopy

from haiku._src.data_structures import FlatMapping

m = FlatMapping.from_mapping({'foo': 'bar'})
deepcopy(m)  # raises TypeError: can't pickle jaxlib.pytree.PyTreeDef objects

P.S. I stumbled upon this because I'm deriving a subclass with limited mutability from FlatMapping (whose leaves are all DeviceArrays). I'm using deepcopy for my target-network / behavior-policy weights.

I added a custom implementation to my derived class:

class Foo(FlatMapping):
    ...

    def __deepcopy__(self, memo):
        leaves, treedef = self.flatten()
        return self.__class__((deepcopy(leaves), treedef))

Also.. thanks for the speed-up in FlatMapping!

AttributeError: Can't pickle local object 'transform_with_state.<locals>.init_fn'

For concurrent.futures.ProcessPoolExecutor or multiprocessing.Pool, it's necessary to pickle objects in order to send them to other processes. That doesn't work with transformed Haiku functions:

import pickle

import haiku as hk

def forward(x):
    return hk.Linear(10)(x)

stateful = hk.transform_with_state(forward)
pickle.dumps(stateful)
# AttributeError: Can't pickle local object 'transform_with_state.<locals>.init_fn'

stateless = hk.transform(forward)
pickle.dumps(stateless)
# AttributeError: Can't pickle local object 'without_state.<locals>.init_fn'

How could we pickle haiku models?

Should we take a different approach?
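One possible direction, as a hedged sketch: the transformed object is a pair of closures, which the standard pickle module cannot serialize; the third-party cloudpickle library usually can. Alternatively, re-create the transform inside each worker and ship only the params, which are plain arrays and pickle fine.

import cloudpickle
import haiku as hk

def forward(x):
    return hk.Linear(10)(x)

payload = cloudpickle.dumps(hk.transform(forward))   # works where pickle fails
restored = cloudpickle.loads(payload)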
