Comments (5)

trevorcai commented on July 17, 2024

I was thinking the same thing; we can view the returned value as an overlay over the original values.
There are a couple of annoying things here to be aware of:

  • For state, updated states will still modify the original state member of StatePair to be the _foil_cse'd one. This is pretty annoying, though not the worst thing ever.
  • The _foil_cse HLO does not get pruned when not used - we're left with a huge number of vestigial constant and rng HLO ops. Blake from XLA team speculates that this is because the ops could be side-effecting.

As a side note, it's not clear to me whether _foil_cse-ing params (which don't need to be _foil_cse'd, as they're constant!) is even a good idea. This is sufficiently murky that I think Haiku shouldn't take a stance on this.
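To make the "overlay" reading at the top of this comment concrete, here is a minimal sketch in plain Python. It assumes params/state are nested dicts of the form `{module_name: {name: value}}`; `overlay_view` is a hypothetical helper for illustration, not part of Haiku's API.

```python
from collections import ChainMap


def overlay_view(original, returned):
    """Builds a read-only style view where `returned` shadows `original`.

    Both arguments are assumed to be {module_name: {name: value}} dicts.
    Neither input is mutated.
    """
    modules = original.keys() | returned.keys()
    return {
        # ChainMap looks up keys in the first mapping, then falls back.
        m: ChainMap(returned.get(m, {}), original.get(m, {}))
        for m in modules
    }


original = {"linear": {"w": 1.0, "b": 0.0}, "bn": {"mean": 0.5}}
returned = {"bn": {"mean": 0.7}}  # only the entry that was updated

view = overlay_view(original, returned)
assert view["bn"]["mean"] == 0.7     # taken from the overlay
assert view["linear"]["w"] == 1.0    # falls through to the original
assert original["bn"]["mean"] == 0.5  # original is untouched
```

With this reading, the inner function only has to hand back the entries it touched; everything else is served from the values the caller already holds.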

from dm-haiku.

tomhennigan commented on July 17, 2024

I like 1, we can definitely be smarter wrt params and state.

I think a general solution would be to return only updated values for both params/state and then merge into the parent frame. I think we can do this by stashing the initial values (we only care about identity) that are inputs to the stateful_fun and only returning ones that have a different identity.

from collections import defaultdict

def only_new_values(original, updated):
  # Keep only entries whose value object differs (by identity) from the original.
  output = defaultdict(dict)
  for mod_name, mod_state in updated.items():
    for name, value in mod_state.items():
      if name in original[mod_name] and original[mod_name][name] is not value:
        output[mod_name][name] = value
  return output

def stateful_fun(*a, **k):
  hk_state = k.pop(..)
  orig_params, orig_state = copy_params_state(hk_state)

  with use_and_update(hk_state):
    out = f(*a, **k)

  # Only return updated values.
  params = only_new_values(orig_params, hk_state.params)
  state = only_new_values(orig_state, hk_state.state)
  return out, (params, state)

I think I agree re PRNG keys too, but would want to triple check the details.
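A runnable sketch of the identity-diff-then-merge idea from the snippet above, assuming params/state are plain nested dicts `{module_name: {name: value}}`. `merge_into` is a hypothetical name for the "merge into the parent frame" step, and lists stand in for arrays. The membership test mirrors the snippet, so entries that did not exist in the original are skipped.

```python
from collections import defaultdict


def only_new_values(original, updated):
    """Returns entries of `updated` whose value object differs from `original`."""
    output = defaultdict(dict)
    for mod_name, mod_state in updated.items():
        orig_mod = original.get(mod_name, {})
        for name, value in mod_state.items():
            # Identity, not equality: a rebound entry counts as changed.
            if name in orig_mod and orig_mod[name] is not value:
                output[mod_name][name] = value
    return dict(output)


def merge_into(parent, changes):
    """Hypothetical merge step: writes changed entries into the parent dict."""
    for mod_name, mod_changes in changes.items():
        parent.setdefault(mod_name, {}).update(mod_changes)
    return parent


w, mean = [1.0], [0.5]  # stand-ins for arrays
original = {"linear": {"w": w}, "bn": {"mean": mean}}

new_mean = [0.7]  # the wrapped function rebinds only this entry
updated = {"linear": {"w": w}, "bn": {"mean": new_mean}}

changes = only_new_values(original, updated)
assert changes == {"bn": {"mean": new_mean}}  # "linear" is not returned

parent = {"linear": {"w": w}, "bn": {"mean": mean}}
merge_into(parent, changes)
assert parent["bn"]["mean"] is new_mean
assert parent["linear"]["w"] is w
```

Because the comparison is by identity, a value that is recomputed to an equal but distinct object would still be reported as changed; that seems acceptable here since the goal is only to avoid returning entries that were never touched.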

trevorcai commented on July 17, 2024

My changes in google/jax#2391 to the XLA CSE mechanism change the semantics of this bug.
Now, we're incurring unnecessary serialization (because parameters are now an output of the remat layer), rather than unnecessary layers of _foil_cse.

The underlying cause remains the same (unnecessarily feeding and updating the Haiku state).

tomhennigan commented on July 17, 2024

I'll take a stab at not updating state if it doesn't change.

trevorcai commented on July 17, 2024

As a result of google/jax#2391, it looks like the problems get optimized away; we're not even incurring unnecessary serialization right now.

These changes may still be good to make defensively.
