Comments (10)
    import os
    import pickle

    import jax
    import numpy as np

    def save(ckpt_dir: str, state) -> None:
        # Write every leaf array back to back into a single file.
        with open(os.path.join(ckpt_dir, "arrays.npy"), "wb") as f:
            for x in jax.tree_util.tree_leaves(state):
                np.save(f, x, allow_pickle=False)

        # Keep only the tree structure by replacing each leaf with a placeholder.
        tree_struct = jax.tree_util.tree_map(lambda t: 0, state)
        with open(os.path.join(ckpt_dir, "tree.pkl"), "wb") as f:
            pickle.dump(tree_struct, f)

    def restore(ckpt_dir):
        with open(os.path.join(ckpt_dir, "tree.pkl"), "rb") as f:
            tree_struct = pickle.load(f)

        leaves, treedef = jax.tree_util.tree_flatten(tree_struct)
        # Read the arrays back in the same order they were written.
        with open(os.path.join(ckpt_dir, "arrays.npy"), "rb") as f:
            flat_state = [np.load(f) for _ in leaves]

        return jax.tree_util.tree_unflatten(treedef, flat_state)

Typed out in the comment box, so it may need some adjustment to actually run.
from dm-haiku.
Hey, we're intentionally un-opinionated here. I will note:
- Haiku params (and network state) are transparent dictionaries of JAX `jnp.ndarray`s. `jnp.ndarray` converts to `np.ndarray`, so when using non-bfloat16 types, anything that works to save NumPy will work here.
There are a few options we've seen work well:
- Directly pickle the params dict. Upside: it just works. Downside: it may not be totally efficient, and it has the usual pickle caveats.
- Use `np.save` or `np.savez` to store the `ndarray`s in a flat format, and save the tree structure via either pickle or a stable serialized format (protobuf, JSON, YAML, you name it).
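As a rough illustration of the second option, here is a minimal sketch that stores the arrays with `np.savez` and the tree structure as JSON. It assumes the params are a nested dict with string keys whose leaves are already NumPy arrays (or convertible with `np.asarray`). The file names, the `leaf_{i}` naming scheme, and the helper names are made up for this example and are not part of Haiku.

```python
import json
import os
import tempfile

import numpy as np


def flatten(tree, prefix=()):
    """Yield (path, leaf) pairs from a nested dict of arrays."""
    for key in sorted(tree):
        value = tree[key]
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield path, np.asarray(value)


def save(ckpt_dir, params):
    items = list(flatten(params))
    # Arrays go into one .npz archive under positional names.
    arrays = {f"leaf_{i}": leaf for i, (_, leaf) in enumerate(items)}
    np.savez(os.path.join(ckpt_dir, "arrays.npz"), **arrays)
    # The tree structure is just the list of key paths, stored as JSON.
    with open(os.path.join(ckpt_dir, "tree.json"), "w") as f:
        json.dump([list(path) for path, _ in items], f)


def restore(ckpt_dir):
    with open(os.path.join(ckpt_dir, "tree.json")) as f:
        paths = json.load(f)
    params = {}
    with np.load(os.path.join(ckpt_dir, "arrays.npz"), allow_pickle=False) as data:
        for i, path in enumerate(paths):
            # Rebuild the nested dicts along the path, then attach the leaf.
            node = params
            for key in path[:-1]:
                node = node.setdefault(key, {})
            node[path[-1]] = data[f"leaf_{i}"]
    return params


# Hypothetical parameters shaped like a Haiku params dict.
example_params = {
    "linear": {"w": np.ones((3, 2), np.float32), "b": np.zeros(2, np.float32)},
}
with tempfile.TemporaryDirectory() as ckpt_dir:
    save(ckpt_dir, example_params)
    restored = restore(ckpt_dir)
```

Because the structure lives in a plain JSON file and the arrays are loaded with `allow_pickle=False`, restoring this kind of checkpoint does not involve unpickling anything.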
I'll look into extending either the Transformer or ResNet example with checkpointing, so we have a concrete piece of code that we can point people to as an example.
I'll leave this bug open conditioned on that - hope this helps!
Hey, nice job! That's correct, we'd like to replace it with the FlatMapping class below it.
The bit is ready to be flipped, we'll look to flip it soon if we can.
Quick update: it turns out the bit is not ready to be flipped, because there are a couple of edge cases that need to be fixed. We don't really have the time to look into this for now, so don't expect it to flip in the near future.
Thanks for the help! @trevorcai
Your advice helped me a lot!
I'm planning to try serialization via protobuf over gRPC, and for the checkpointing I'll wait for your examples :)
https://github.com/chris-chris/haiku-scalable-example/pull/1/files
Thanks, @trevorcai!
I wrote an encoder and decoder that convert the Haiku model weights and trajectories to and from gRPC protobuf messages.
I noticed that you use `frozendict` as the data structure for the model weights, and that there is a TODO comment on this type:
    # TODO(lenamartens) Deprecate type
https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/data_structures.py#L80
Is this data type going to be deprecated?
I've been using https://github.com/cloudpipe/cloudpickle/ which seems to be working well.
What library is recommended for directly serializing the params dict, and what are the caveats? I think adding these to the docs would be a nice addition, or at least links to other good docs on serialization in Python.
If you use `HAIKU_FLATMAPPING=0`, then Haiku checkpointing is as simple as serializing `dict`s of `np.ndarray`s; any solution that works for that will work for Haiku.
The transformer example is a simple demonstration of pickle-ing the entire state:
https://github.com/deepmind/dm-haiku/blob/main/examples/transformer/train.py#L168-L218
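For readers who just want the shape of the idea, here is a minimal sketch of pickling a whole training state and loading it back. It is not the code from the linked example. The `state` layout and the file name are hypothetical, and it assumes the arrays have already been brought to host memory as NumPy arrays (for instance with `jax.device_get`).

```python
import os
import pickle
import tempfile

import numpy as np

# Hypothetical training state: a step counter plus nested dicts of arrays.
state = {
    "step": 100,
    "params": {"linear": {"w": np.ones((3, 2), np.float32)}},
}

with tempfile.TemporaryDirectory() as ckpt_dir:
    path = os.path.join(ckpt_dir, "checkpoint.pkl")
    with open(path, "wb") as f:
        pickle.dump(state, f)
    # Only unpickle files you trust: loading can execute arbitrary code.
    with open(path, "rb") as f:
        loaded_state = pickle.load(f)
```

This is the "it just works" route. The trade-off is that the checkpoint is tied to pickle and carries its usual security and compatibility caveats.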
Two years on, the vast majority of people at DeepMind use `np.save` to store the `np.ndarray`s in a flat format, and save the tree structure separately through `pickle` or a specialized internal format (that I don't know the details of because I use `pickle`).
@trevorcai Is there an example of saving the tree structure and then loading the np.ndarrays back into it?