patrick-kidger / quax
Multiple dispatch over abstract array types in JAX.
License: Apache License 2.0

Hi! I'm running into a few issues with quax + jnp.zeros_like. The built-in quax.zeros.Zeros appears to leak into jnp.zeros_like and its related functions:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from quax import quaxify
>>> jax.config.update("jax_enable_x64", True)  # needed for the float64 dtype below
>>> x = jnp.array([1, 2, 3], dtype=jnp.float64)
>>> quaxify(jnp.zeros_like)(x)
Zero(_shape=(3,), _dtype=dtype('float64'))
>>> quaxify(jnp.empty_like)(x)
Zero(_shape=(3,), _dtype=dtype('float64'))
```
I think quaxify(jnp.zeros_like) applied to a plain jax.Array should return a jax.Array.

More importantly (to me), when I make a custom quax.ArrayValue subclass, it doesn't override the behavior of Zero for these functions:
```python
>>> import quax
>>> class MyArray(quax.ArrayValue): ...
>>> ... [override all related primitives]
>>> y = MyArray(x)
>>> quaxify(jnp.zeros_like)(y)
Zero(_shape=(3,), _dtype=dtype('float64'))
```

Just a suggestion to add a small section to the docs about best practices for quaxifying jitted and vmapped functions. Thanks!

In #10 (comment) we discussed a function wrap_arrayish_value_into_tracer_like to handle cases like:

```python
def funct(x, y):
    z = wrap_arrayish_value_into_tracer_like(Value(...), x)
    return x + y + z
```
In GalacticDynamics/galax#187 I'm running into the lack of this functionality in the _potential_energy methods, where the Parameters of the potential do not pass through the same quax boundary:

```python
def _potential_energy(self, xyz, t, _G):
    m = self.some_param(t)
    return _G * m / ...
```
In our case some_param is often a callable object with a .value attribute. In testing, self.some_param.value is not wrapped into a tracer when self is wrapped. Would recursively wrapping everything in the PyTree fix this problem?
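To make the question concrete, here is a minimal sketch of what "recursively wrapping into the PyTree" could mean, written in plain JAX: jax.tree_util.tree_map visits every array leaf, including leaves nested inside parameter objects, not just the top-level arguments. The Tagged class and wrap_all_leaves helper are hypothetical stand-ins for whatever wrapping quax would apply (e.g. into a tracer); neither is part of quax's API.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Tagged:
    # Hypothetical stand-in for a quax-level wrapper; NOT part of quax's API.
    # It is not registered as a pytree, so tree_map treats it as an opaque leaf.
    value: jax.Array


def wrap_all_leaves(tree):
    # Walk the entire PyTree and wrap every array leaf, however deeply nested.
    # Non-array leaves (strings, Python numbers, ...) are left untouched.
    return jax.tree_util.tree_map(
        lambda leaf: Tagged(leaf) if isinstance(leaf, jax.Array) else leaf, tree
    )


# A nested structure standing in for a potential whose parameters hold arrays.
params = {"m": jnp.array(1.0e12), "inner": {"value": jnp.array(2.0)}, "name": "kepler"}
wrapped = wrap_all_leaves(params)
```

The traversal itself is cheap; whether quax should perform it at the quaxify boundary is the open design question.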

Based on quax, I wrote array-api-jax-compat for use with jax-quantity. Would you be interested in upstreaming array-api-jax-compat as a submodule, e.g. quax.array_api? With the submodule, users wouldn't have to quaxify any array-API function themselves; they could just import quax.array_api as xp.
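A minimal sketch of the idea: build a namespace whose functions are already wrapped, so callers never wrap anything themselves. In the proposal the wrapper would be quax.quaxify; here it is passed in as a parameter (and the helper name make_wrapped_namespace is hypothetical) so the sketch runs without quax installed.

```python
import types
from typing import Callable, Iterable

import jax.numpy as jnp


def make_wrapped_namespace(source, names: Iterable[str], wrap: Callable):
    # Hypothetical helper: copy each named function from `source` (e.g. jax.numpy)
    # into a fresh namespace, applying `wrap` (e.g. quax.quaxify) exactly once.
    ns = types.SimpleNamespace()
    for name in names:
        setattr(ns, name, wrap(getattr(source, name)))
    return ns


# With quax installed, the proposal would correspond to
#   xp = make_wrapped_namespace(jnp, ["sum", "multiply"], quax.quaxify)
# An identity wrapper keeps this example runnable without quax.
xp = make_wrapped_namespace(jnp, ["sum", "multiply"], lambda f: f)
total = xp.sum(xp.multiply(jnp.arange(4), 2))
```

Wrapping once at import time would let downstream code call xp.sum(...) on quax array types directly.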

Hi! I love the possibilities of quax and would like to use it for a unit system in my acoustic wave simulation. For this, I need to register a rule for jax.lax.while_p, since the simulation runs for many thousands of steps in a while loop. Could you give me some tips on the implementation? An MWE of my current attempt looks something like this:
```python
import jax
import jax.numpy as jnp
import quax
import jax.core as core


class ArrayWrapper(quax.ArrayValue):
    array: jax.Array

    def aval(self):
        return core.ShapedArray(self.array.shape, self.array.dtype)

    def materialise(self) -> jax.Array:
        raise ValueError("Refusing to materialise")


@quax.register(jax.lax.while_p)
def _(*args, cond_nconsts: int, cond_jaxpr, body_nconsts: int, body_jaxpr):
    new_args = [a.array if isinstance(a, ArrayWrapper) else a for a in args]
    is_quaxed = [isinstance(a, ArrayWrapper) for a in args]
    out = jax.lax.while_p.bind(
        *new_args,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )
    return [
        ArrayWrapper(a) if quaxed else a
        for a, quaxed in zip(out, is_quaxed[cond_nconsts + body_nconsts:])
    ]


def body_fn(a: jax.Array):
    return a + 1


def cond_fn(a: jax.Array):
    return a[1] < 10


def loop_fn(a: jax.Array):
    res = jax.lax.while_loop(
        body_fun=body_fn,
        cond_fun=cond_fn,
        init_val=a,
    )
    return res


a = ArrayWrapper(jnp.arange(10))
a = jax.jit(quax.quaxify(loop_fn))(a)
print(a.array)
```
Even though this code executes, it is far from optimal. Since body_fn and cond_fn are no longer quaxed, all the advantages of the unit system are lost. In particular, the code above should raise an exception, because the primitives for addition, less-than, etc. are not registered for ArrayWrapper; instead it runs without complaint.
I don't know how to integrate the quaxed functions into the XLA WhileOp primitive. Do you have any insights on how to achieve this?
Many thanks
Yannik
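One possible direction, sketched here in plain JAX as a design illustration rather than quax's API: the rule above only receives cond_jaxpr and body_jaxpr that were already traced on the unwrapped arrays, so unit checks inside the loop never run. If instead only the raw array crosses the jax.lax.while_loop boundary, and the Python-level cond and body are re-wrapped on entry, the unit logic runs at trace time inside the loop. Unitful and unitful_while_loop are hypothetical names; Unitful is a toy stand-in for a quax.ArrayValue subclass with a static unit string.

```python
import jax
import jax.numpy as jnp


class Unitful:
    # Hypothetical stand-in for a quax.ArrayValue subclass: a raw array plus a
    # static unit string. Unit checks happen in Python, i.e. at trace time.
    def __init__(self, value, unit: str):
        self.value = value
        self.unit = unit

    def _check(self, other) -> None:
        other_unit = getattr(other, "unit", None)
        if other_unit != self.unit:
            raise TypeError(f"unit mismatch: {self.unit!r} vs {other_unit!r}")

    def __add__(self, other):
        self._check(other)
        return Unitful(self.value + other.value, self.unit)

    def __lt__(self, other):
        self._check(other)
        return self.value < other.value  # plain boolean array, no unit

    def __getitem__(self, idx):
        return Unitful(self.value[idx], self.unit)


def unitful_while_loop(cond_fun, body_fun, init_val: Unitful) -> Unitful:
    unit = init_val.unit  # static, so it never becomes part of the loop carry

    # Only raw arrays cross the lax.while_loop boundary; cond_fun and body_fun
    # still see Unitful values, so their unit checks run while being traced.
    def raw_cond(raw):
        return cond_fun(Unitful(raw, unit))

    def raw_body(raw):
        out = body_fun(Unitful(raw, unit))
        if getattr(out, "unit", None) != unit:
            raise TypeError("body_fun must preserve the carry's unit")
        return out.value

    return Unitful(jax.lax.while_loop(raw_cond, raw_body, init_val.value), unit)


one = Unitful(jnp.array(1.0), "m")
ten = Unitful(jnp.array(10.0), "m")
a = unitful_while_loop(lambda x: x[1] < ten, lambda x: x + one, Unitful(jnp.arange(10.0), "m"))

# Adding seconds to metres inside the body now fails while the loop is traced.
try:
    unitful_while_loop(
        lambda x: x[1] < ten,
        lambda x: x + Unitful(jnp.array(1.0), "s"),
        Unitful(jnp.arange(10.0), "m"),
    )
    mismatch_caught = False
except TypeError:
    mismatch_caught = True
```

Doing the same inside a while_p rule would require the equivalent re-wrapping step for cond_jaxpr and body_jaxpr, which this sketch does not attempt.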

Hi! Really enjoying quax. I've been working to get galax potentials quaxified (most relevantly in GalacticDynamics/galax#187) and ran into a compatibility issue with jaxtyping's runtime type checking.
If a module has runtime type checking turned on, e.g. install_import_hook("galax.potential", "beartype.beartype"), then quaxed functions don't pass objects through correctly. As an example from GalacticDynamics/galax#187:
```python
>>> from jax_quantity import Quantity
>>> import galax.potential as gp
>>> import galax.units as gu
>>> pot = gp.KeplerPotential(m=Quantity(1e12, "Msun"), units=gu.galactic)
>>> pot._potential_energy(Quantity([1.0, 0, 0], "kpc"), Quantity(0, "Myr"), pot._G)
TypeCheckError: Type-check error whilst checking the parameters of _potential_energy.
The problem arose whilst typechecking argument 'q'.
Called with arguments: {'self': KeplerPotential(...), 'q': f64[3], 't': i64[], '_G': f64[]}
Parameter annotations: (self, q: Shaped[Quantity, '*batch 3'], t: Union[Shaped[Quantity, '*#batch '], Shaped[Quantity, '*#batch ']], /, _G: Float[Quantity, '']).
```

In GalacticDynamics/unxt#4 I'm trying to get jax.grad to work on functions that accept Quantity arguments, and have run into some difficulties. The following doesn't work:
```python
import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify

jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")


def func(q) -> Quantity:
    return 5 * q**2 + Quantity(1.0, unit="m2")


quaxify(jax.grad)(func)(q[0])
```
It fails with the error TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m2")).
This error is expected, since grad checks for scalar outputs (with jax._src.api._check_scalar). The underlying issue appears to be that _check_scalar calls concrete_aval, which errors on Quantity.
quax-compatible classes have an aval() method, so I hooked that up to a handler and registered it in pytype_aval_mappings:

```python
jax._src.core.pytype_aval_mappings[Quantity] = lambda q: q.aval()
quaxify(jax.grad)(func)(q[0])
```
While this gets a few lines further in grad, unfortunately it causes a disagreement between pytree structures, with the error TypeError: Tree structure of cotangent input PyTreeDef(*), does not match structure of primal output PyTreeDef(CustomNode(Quantity[('value',), ('unit',), (Unit("m2"),)], [*])).
I haven't figured out how to fix this issue. Any suggestions would be appreciated!
p.s. @dfm has figured out how to do grad on Quantity in jpu by shunting the units to aux data and re-assembling them afterwards. This solution works well, but it is unique to Quantity and requires a custom grad function. I was hoping to get this working with quaxify in a way that doesn't require dispatching (via plum, in https://github.com/GalacticDynamics/array-api-jax-compat) to library-specific grad implementations, especially since it's not obvious what to dispatch on to map func to the Quantity implementation.
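For comparison, a minimal sketch of the "strip the units, differentiate the raw values, reassemble" idea in plain JAX. It is not jpu's implementation: Q and grad_with_units are hypothetical, the unit is a plain string, func is written directly on .value and .unit, and the output unit is captured in a Python closure at trace time rather than passed as has_aux data.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Q(NamedTuple):
    # Hypothetical minimal quantity: a raw value plus a unit string.
    value: jax.Array
    unit: str


def grad_with_units(func: Callable[[Q], Q]) -> Callable[[Q], Q]:
    def wrapper(q: Q) -> Q:
        out_units = []  # filled in while jax.grad traces raw_func

        def raw_func(x):
            # Re-attach the input unit, call func, then strip the output unit
            # so jax.grad only ever sees a plain scalar array.
            out = func(Q(x, q.unit))
            out_units.append(out.unit)
            return out.value

        g = jax.grad(raw_func)(q.value)
        # d(out)/d(in) carries the unit out_unit / in_unit.
        return Q(g, f"({out_units[0]})/({q.unit})")

    return wrapper


# Same function as in the report, f(q) = 5 q^2 + 1 m2, written on the raw value.
def f(q: Q) -> Q:
    return Q(5 * q.value**2 + 1.0, "m2")


dq = grad_with_units(f)(Q(jnp.array(1.0), "m"))
```

Because jax.grad never sees a Q, no pytype_aval_mappings entry is needed; the cost is that every function must be differentiated through this wrapper rather than through quaxify.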

Hi, in the Unitful example, given the numerical values you are using, the mass unit should be kg instead of g:
g = 5.9722e24 / 6.3781e6**2 * 6.67430e-11 ≈ 9.798 m/s^2
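A quick numerical check of that arithmetic, in SI units (G in m^3 kg^-1 s^-2, M in kg, r in m): g = G M / r^2 comes out near 9.8 m/s^2 only when the mass is in kilograms; reading the same number as grams shrinks the result by a factor of 1000.

```python
G = 6.67430e-11  # gravitational constant, m^3 kg^-1 s^-2
M = 5.9722e24    # Earth's mass, kg
r = 6.3781e6     # Earth's equatorial radius, m

g = G * M / r**2  # surface gravity, m / s^2

# If the same number were in grams, converting to kg first divides g by 1000.
g_if_grams = G * (M / 1000) / r**2
```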