Comments (10)

trevorcai commented on July 17, 2024
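import os
import pickle

import jax
import numpy as np


def save(ckpt_dir: str, state) -> None:
  # Write every leaf array, in tree-flatten order, into a single .npy file.
  with open(os.path.join(ckpt_dir, "arrays.npy"), "wb") as f:
    for x in jax.tree_util.tree_leaves(state):
      np.save(f, x, allow_pickle=False)

  # Pickle only the tree structure (every leaf replaced by 0).
  tree_struct = jax.tree_util.tree_map(lambda t: 0, state)
  with open(os.path.join(ckpt_dir, "tree.pkl"), "wb") as f:
    pickle.dump(tree_struct, f)


def restore(ckpt_dir: str):
  with open(os.path.join(ckpt_dir, "tree.pkl"), "rb") as f:
    tree_struct = pickle.load(f)

  # Read back one array per leaf, in the same order they were written.
  leaves, treedef = jax.tree_util.tree_flatten(tree_struct)
  with open(os.path.join(ckpt_dir, "arrays.npy"), "rb") as f:
    flat_state = [np.load(f, allow_pickle=False) for _ in leaves]

  return jax.tree_util.tree_unflatten(treedef, flat_state)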

Typed out in the comment box, so it may need some adjustment to actually run.

from dm-haiku.

trevorcai commented on July 17, 2024

Hey, we're intentionally unopinionated here. A couple of notes:

  1. Haiku params (and network state) are transparent dictionaries of JAX jnp.ndarrays.
  2. jnp.ndarray converts to np.ndarray, so for non-bfloat16 dtypes, anything that works for saving NumPy arrays will work here.

There are a few options we've seen work well:

  • Directly pickle the params dict. Upside: it just works. Downside: it may not be the most efficient option, and it comes with the usual pickle caveats.
  • Use np.save or np.savez to store the ndarrays in a flat format, and save the tree structure via either pickle or a stable serialization format (protobuf, JSON, YAML, you name it).
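
A minimal sketch of the second option, assuming the params are a nested dict of np.ndarray (as plain-dict Haiku params are). The helper names `flatten_paths`, `tree_with_paths`, `save_params` and `load_params` are hypothetical, and joining keys with "/" is an arbitrary choice; it only needs to produce unique paths for your tree:

```python
import json
import os

import numpy as np


def flatten_paths(tree, prefix=""):
  # Yield (path, array) pairs for every leaf of a nested dict of arrays.
  for key, value in tree.items():
    path = f"{prefix}/{key}" if prefix else key
    if isinstance(value, dict):
      yield from flatten_paths(value, path)
    else:
      yield path, value


def tree_with_paths(tree, prefix=""):
  # Same nesting as `tree`, but each leaf is replaced by its path string,
  # so the structure can be stored as JSON.
  out = {}
  for key, value in tree.items():
    path = f"{prefix}/{key}" if prefix else key
    out[key] = tree_with_paths(value, path) if isinstance(value, dict) else path
  return out


def save_params(ckpt_dir, params):
  # Arrays go into one .npz archive keyed by path; the structure goes into JSON.
  np.savez(os.path.join(ckpt_dir, "arrays.npz"), **dict(flatten_paths(params)))
  with open(os.path.join(ckpt_dir, "tree.json"), "w") as f:
    json.dump(tree_with_paths(params), f)


def load_params(ckpt_dir):
  with open(os.path.join(ckpt_dir, "tree.json")) as f:
    structure = json.load(f)
  path = os.path.join(ckpt_dir, "arrays.npz")
  with np.load(path, allow_pickle=False) as arrays:
    def fill(node):
      # Walk the JSON structure and swap each path string for its array.
      if isinstance(node, dict):
        return {k: fill(v) for k, v in node.items()}
      return arrays[node]
    return fill(structure)
```

Because the JSON file only holds strings and nesting, it stays readable without pickle and without importing Haiku at load time.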

I'll look into extending either the Transformer or ResNet example with checkpointing, so we have a concrete piece of code that we can point people to as an example.

I'll leave this issue open until that example is in place. Hope this helps!

trevorcai commented on July 17, 2024

Hey, nice job! That's correct: we'd like to replace frozendict with the FlatMapping class below it.
The bit is ready to be flipped; we'll look to flip it soon if we can.

trevorcai commented on July 17, 2024

Quick update: it turns out the bit is not ready to be flipped; there are a couple of edge cases that need to be fixed. We don't have time to look into this right now, so don't expect it to flip in the near future.

chris-chris commented on July 17, 2024

Thanks for the help! @trevorcai

Your advice helped me a lot!

I'm planning to try serialization via protobuf over gRPC.
As for checkpointing, I'll wait for your examples :)
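
The protobuf message definition isn't shown in this thread, so as a hedged sketch only: an ndarray can be reduced to three plain fields (dtype string, shape, raw bytes) that map naturally onto protobuf `string`, `repeated int64` and `bytes` fields. A dict stands in for the message here, and `encode_array`/`decode_array` are hypothetical names. As with the non-bfloat16 note above, this assumes a standard NumPy dtype:

```python
import numpy as np


def encode_array(x):
  # Reduce an array to fields a protobuf message could carry.
  # tobytes() emits C-order bytes even for non-contiguous arrays.
  x = np.asarray(x)
  return {"dtype": x.dtype.str, "shape": list(x.shape), "data": x.tobytes()}


def decode_array(msg):
  # dtype.str (e.g. "<f4") records byte order, so the bytes decode portably.
  arr = np.frombuffer(msg["data"], dtype=np.dtype(msg["dtype"]))
  # frombuffer returns a read-only view of the bytes, so copy after reshaping.
  return arr.reshape(msg["shape"]).copy()
```

A params tree would then be sent as a map from parameter path to one such record per array.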

chris-chris commented on July 17, 2024

https://github.com/chris-chris/haiku-scalable-example/pull/1/files

Thanks! @trevorcai

I made an encoder and decoder that convert the Haiku model weights and trajectories to and from gRPC protobuf messages.
I noticed that you use frozendict as the data structure for model weights, and there is a comment on this data type.

Is this data type going to be deprecated?

# TODO(lenamartens) Deprecate type

https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/data_structures.py#L80

asmith26 commented on July 17, 2024

I've been using https://github.com/cloudpipe/cloudpickle/ which seems to be working well.

NightMachinery commented on July 17, 2024

What library is recommended for directly serializing the params dict? What are the caveats? I think adding these to the docs will be a nice addition, or at least links to other good docs on serialization in Python.

trevorcai commented on July 17, 2024

If you use HAIKU_FLATMAPPING=0, then Haiku checkpointing is as simple as serializing dicts of np.ndarrays; any solution that works for that will work for Haiku.
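
If the params arrive as a FlatMapping (or frozendict) instead, one option is to convert them to plain nested dicts of np.ndarray before handing them to the serializer. A minimal sketch, assuming those types behave as read-only `collections.abc.Mapping`s; `to_plain_dicts` is a hypothetical helper, not a Haiku API:

```python
from collections.abc import Mapping

import numpy as np


def to_plain_dicts(tree):
  # Recursively copy any Mapping into a builtin dict and turn every leaf
  # into an np.ndarray (np.asarray also accepts jax arrays).
  if isinstance(tree, Mapping):
    return {key: to_plain_dicts(value) for key, value in tree.items()}
  return np.asarray(tree)
```

When trying it out without Haiku, `types.MappingProxyType` works as a stand-in read-only mapping.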

The transformer example is a simple demonstration of pickling the entire state:
https://github.com/deepmind/dm-haiku/blob/main/examples/transformer/train.py#L168-L218
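
For plain dicts of arrays, pickling the whole state takes only a few lines. A minimal sketch (the function names are hypothetical, and this is not the transformer example's code). The main pickle caveat is that loading a pickle can execute arbitrary code, so only load checkpoints you trust:

```python
import pickle


def save_state(path, state):
  # Pickle the whole nested structure (dicts, arrays, counters) in one go.
  with open(path, "wb") as f:
    pickle.dump(state, f)


def load_state(path):
  # Only unpickle trusted files: pickle can run arbitrary code while loading.
  with open(path, "rb") as f:
    return pickle.load(f)
```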

Two years on, the vast majority of people at DeepMind use np.save to store the np.ndarrays in a flat format and save the tree structure separately, through pickle or a specialized internal format (whose details I don't know, because I use pickle).
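
The flat np.save format relies on np.save being callable repeatedly on one open file, with np.load reading the arrays back one at a time in the same order; the save/restore functions at the top of this thread depend on exactly that. A minimal NumPy-only illustration, where the `leaves` list stands in for a flattened tree and an in-memory buffer stands in for the checkpoint file:

```python
import io

import numpy as np

leaves = [np.arange(3, dtype=np.int32), np.ones((2, 2), dtype=np.float32)]

# Write each array back-to-back into a single stream.
buf = io.BytesIO()
for x in leaves:
  np.save(buf, x, allow_pickle=False)

# Rewind, then read them back in order; each np.load call consumes one array.
buf.seek(0)
loaded = [np.load(buf, allow_pickle=False) for _ in leaves]
```

Since the file stores no names, the number and order of leaves must come from the separately saved tree structure.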

NightMachinery commented on July 17, 2024

@trevorcai Is there an example of saving the tree structure and then loading the np.ndarrays back into it?
