Comments (9)

nstarman commented on June 14, 2024

Thanks, I was actually not aware of jax.experimental.array_api, but that will make my life significantly easier!
Each function in https://github.com/GalacticDynamics/array-api-jax-compat/ will then be a quax-ified, plum-dispatched wrapper around jax.experimental.array_api, plus miscellaneous quax-ified functions such as jacfwd and grad.

The goal is to support Astropy-like Quantities in JAX (https://github.com/GalacticDynamics/jax-quantity). In that repo I've completed most of the Astropy -> quax -> JAX bridges.

from quax.

patrick-kidger commented on June 14, 2024

So this is intentional but definitely questionable.

The reason for this is this rule for Zero, which will bind against any ArrayValue (which is the fill value for the array). This is unusual in that we don't need an instance of the corresponding array-ish Zero as an input to the rule.

I think changing this may be impossible without also making jnp.zeros(...) return an Array. Both use identical primitive binds inside JAX: loosely speaking, they're implemented as def zeros(shape, dtype): return broadcast(0, shape, dtype) and def zeros_like(array): return broadcast(0, array.shape, array.dtype).
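
One way to check this, as a sketch assuming a recent JAX release, is to trace both functions with jax.make_jaxpr and list the primitives each one binds. The helper primitives below is hypothetical, written only for this illustration.

```python
import jax
import jax.numpy as jnp


def primitives(fn, *args):
    """Return the names of the primitives bound when tracing fn(*args)."""
    closed = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


# jnp.zeros takes only a shape and dtype, so trace it through a zero-argument lambda.
zeros_prims = primitives(lambda: jnp.zeros((2, 3), jnp.float32))

# jnp.zeros_like takes an array, but only needs its shape and dtype.
zeros_like_prims = primitives(jnp.zeros_like, jnp.ones((2, 3), jnp.float32))

print(zeros_prims)
print(zeros_like_prims)
```

Both lists should contain broadcast_in_dim, which is the primitive a quax rule would have to intercept; the exact list (for instance, whether a convert_element_type also shows up) may differ between JAX releases.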

I'm not completely sure how this might be tackled; I'd welcome any thoughts.

nstarman commented on June 14, 2024

Thanks for the fast response.

> The reason for this is this rule for Zero, which will bind against any ArrayValue (which is the fill value for the array). This is unusual in that we don't need an instance of the corresponding array-ish Zero as an input to the rule.

I had registered a specific dispatch for the primitive lax.broadcast_in_dim_p. However, I don't think this is ever called with MyArray, as the Zero is instantiated, then fed into the broadcast.

> I'm not completely sure how this might be tackled; I'd welcome any thoughts.

Short of monkey-patching JAX, I don't know either. I think I might just use plum on a custom zeros_like function to dispatch to MyArray (and Zero). That should work, if a little less elegantly than quax's internal use of plum.
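
A minimal sketch of that idea follows. It uses functools.singledispatch from the standard library as a stand-in for plum (plum dispatches on all arguments, singledispatch only on the first), and MyArray here is a hypothetical Quantity-like wrapper, not a quax ArrayValue.

```python
from dataclasses import dataclass
from functools import singledispatch

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class MyArray:
    """Hypothetical Quantity-like wrapper: a JAX array plus a unit string."""

    value: jax.Array
    unit: str


@singledispatch
def zeros_like(x):
    # Fallback for plain JAX arrays and other array-likes: defer to jax.numpy.
    return jnp.zeros_like(x)


@zeros_like.register
def _(x: MyArray):
    # Keep the wrapper type and its unit instead of returning a bare array.
    return MyArray(jnp.zeros_like(x.value), x.unit)


q = MyArray(jnp.array([1.0, 2.0, 3.0]), "m")
print(zeros_like(q))
print(zeros_like(jnp.ones(2)))
```

Because the dispatch happens before any JAX primitive is bound, this sidesteps the broadcast_in_dim rule entirely; the trade-off is that code calling jnp.zeros_like directly, rather than this wrapper, still gets a plain array.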

patrick-kidger commented on June 14, 2024

> as the Zero is instantiated, then fed into the broadcast.

Actually, it's a JAX array that's instantiated! But indeed, it's not a custom MyArray either way.
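
As a sketch of why (assuming a recent JAX release), tracing jnp.zeros_like shows that its input is only inspected for shape and dtype: no equation in the resulting jaxpr reads the input variable, so a rule registered on broadcast_in_dim_p only ever sees the constant fill value. The helper reads_input is hypothetical.

```python
import jax
import jax.numpy as jnp


def reads_input(fn, x):
    """Return True if the jaxpr of fn(x) uses x as an operand or returns it."""
    closed = jax.make_jaxpr(fn)(x)
    (x_var,) = closed.jaxpr.invars
    # Compare by identity: jaxpr variables are shared objects, and literals
    # appearing in eqn.invars are not guaranteed to be hashable.
    used_by_eqn = any(v is x_var for eqn in closed.jaxpr.eqns for v in eqn.invars)
    returned = any(v is x_var for v in closed.jaxpr.outvars)
    return used_by_eqn or returned


x = jnp.ones((2, 3))
print(reads_input(jnp.zeros_like, x))     # zeros_like: only shape/dtype are inspected
print(reads_input(lambda a: a + 1.0, x))  # contrast: the add consumes its input
```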

FWIW, I'm contemplating simply removing that broadcast dispatch rule from Quax. It's always been a bit magic that we produce a Zero without having a Zero input. It's also not totally reliable: something like quax.quaxify(jnp.zeros_like)(jnp.array(0)) doesn't trigger it, as that doesn't involve a broadcast. After that change, zeros/zeros_like would unconditionally produce a normal JAX array.

I'm not sure how much that helps you of course, but it's something.

nstarman commented on June 14, 2024

Still very much a work in progress, but check out https://github.com/GalacticDynamics/array-api-jax-compat/, where I'm leveraging quax to make a bridge to the Array API that also works with JAX array-ish objects.

patrick-kidger commented on June 14, 2024

Oh, this looks neat! Can you tell me a bit more about this?

I'm noticing a similarity to jax.experimental.array_api; I assume you're building off of that as well?

nstarman commented on June 14, 2024

array-api-jax-compat is now mostly a quax wrapper around jax.experimental.array_api.

patrick-kidger commented on June 14, 2024

Awesome!

nstarman commented on June 14, 2024

I think this is now resolved!
