
quax's Introduction

Quax

JAX + multiple dispatch + custom array-ish objects

For example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul (W + AB)v into the more efficient Wv + ABv.
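To make the LoRA rewrite concrete, here is a small sketch of why the second form is cheaper (illustrative shapes only; not part of the library): it never materialises the m-by-n matrix A @ B.

import jax.numpy as jnp
import jax.random as jr

m, n, r = 512, 256, 4  # the LoRA rank r is much smaller than m and n
key = jr.PRNGKey(0)
W = jr.normal(jr.fold_in(key, 0), (m, n))
A = jr.normal(jr.fold_in(key, 1), (m, r))
B = jr.normal(jr.fold_in(key, 2), (r, n))
v = jr.normal(jr.fold_in(key, 3), (n,))

naive = (W + A @ B) @ v          # materialises A @ B: O(m*n*r) extra flops
efficient = W @ v + A @ (B @ v)  # two skinny matvecs: O((m + n)*r) extra flops
assert jnp.allclose(naive, efficient, atol=1e-2)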

Applications include:

  • LoRA weight matrices
  • symbolic zeros
  • arrays with named dimensions
  • structured (e.g. tridiagonal) matrices
  • sparse arrays
  • quantised arrays
  • arrays with physical units attached
  • etc! (See the built-in quax.examples library for most of the above!)

This works via a custom JAX transform: take an existing JAX program, wrap it in quax.quaxify, and then pass in the custom array-ish objects. This means it will work even with existing programs that were not written to accept such array-ish objects!

(Just as jax.vmap takes a program and reinterprets each operation as its batched version, so too does quax.quaxify take a program and reinterpret each operation according to which array-ish types are passed.)
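For a flavour of what defining your own array-ish type involves, here is a minimal sketch (Wrapper and the final one-liner are illustrative, not from the library; it assumes the interface used in the docs and in the issues further below: an aval method describing shape/dtype, a materialise method for falling back to a plain array, and quax.register rules that dispatch on the annotated argument types):

import jax
import jax.core as core
import jax.numpy as jnp
import quax

class Wrapper(quax.ArrayValue):
    # An array-ish type is a pytree; this one simply wraps a single array.
    array: jax.Array

    def aval(self) -> core.ShapedArray:
        # The abstract shape/dtype that JAX should trace this object as.
        return core.ShapedArray(self.array.shape, self.array.dtype)

    def materialise(self) -> jax.Array:
        # How to recover an ordinary array when no dispatch rule applies.
        return self.array

@quax.register(jax.lax.add_p)
def _(x: Wrapper, y: Wrapper) -> Wrapper:
    # Dispatch rule: unwrap, bind the underlying primitive, re-wrap.
    return Wrapper(jax.lax.add_p.bind(x.array, y.array))

quax.quaxify(lambda a, b: a + b)(Wrapper(jnp.arange(3.0)), Wrapper(jnp.ones(3)))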

Installation

pip install quax

Documentation

Available at https://docs.kidger.site/quax.

Example: LoRA

This example demonstrates everything you need to use the built-in quax.examples.lora library.

import equinox as eqx
import jax.random as jr
import quax
import quax.examples.lora as lora

#
# Start off with any JAX program: here, the forward pass through a linear layer.
#

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
linear = eqx.nn.Linear(10, 12, key=key1)
vector = jr.normal(key2, (10,))

def run(model, x):
  return model(x)

run(linear, vector)  # can call this as normal

#
# Now let's LoRA-ify it.
#

# Step 1: make the weight be a LoraArray.
lora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)
lora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)
# Step 2: quaxify and call the original function. The transform will call the
# original function, whilst looking up any multiple dispatch rules registered.
# (In this case for doing matmuls against LoraArrays.)
quax.quaxify(run)(lora_linear, vector)
# Appendix: Quax includes a helper to automatically apply Step 1 to all
# `eqx.nn.Linear` layers in a model.
lora_linear = lora.loraify(linear, rank=2, key=key3)
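# The loraify'd model is then used exactly as in Step 2 above:
quax.quaxify(run)(lora_linear, vector)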

Work in progress!

Right now, the following are not supported:

  • Control flow primitives (e.g. jax.lax.cond).
  • jax.custom_vjp

It should be fairly straightforward to add support for these; open an issue or pull request.

See also: other libraries in the JAX ecosystem

Equinox: neural networks.

jaxtyping: type annotations for shape/dtype of arrays.

Optax: first-order gradient (SGD, Adam, ...) optimisers.

Diffrax: numerical differential equation solvers.

Optimistix: root finding, minimisation, fixed points, and least squares.

Lineax: linear solvers.

BlackJAX: probabilistic+Bayesian sampling.

Orbax: checkpointing (async/multi-host/multi-device).

sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.

Eqxvision: computer vision models.

Levanter: scalable+reliable training of foundation models (e.g. LLMs).

PySR: symbolic regression. (Non-JAX honourable mention!)

Acknowledgements

Significantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and jax.experimental.sparse.

quax's People

Contributors

nstarman, patrick-kidger



quax's Issues

How to implement jax.lax.while with quax

Hi! I love the possibilities of quax and would like to use it for a unit system in my acoustic wave simulation. For this, it is necessary to register a function for jax.lax.while_p, since the simulation runs for many thousands of steps in a while loop. I was wondering if you could give some tips for the implementation. An MWE of my current attempt looks something like this:

import jax
import jax.numpy as jnp
import quax
import jax.core as core

class ArrayWrapper(quax.ArrayValue):
    array: jax.Array

    def aval(self):
        return core.ShapedArray(self.array.shape, self.array.dtype)

    def materialise(self) -> jax.Array:
        raise ValueError("Refusing to materialise")
        
@quax.register(jax.lax.while_p)
def _(*args, cond_nconsts: int, cond_jaxpr, body_nconsts: int, body_jaxpr):
    # Unwrap any ArrayWrapper arguments, bind the underlying while_p, and
    # re-wrap the corresponding outputs afterwards.
    new_args = [a.array if isinstance(a, ArrayWrapper) else a for a in args]
    is_quaxed = [isinstance(a, ArrayWrapper) for a in args]
    out = jax.lax.while_p.bind(
        *new_args,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )
    # The carried values come after the cond/body constants.
    return [
        ArrayWrapper(a) if quaxed else a
        for a, quaxed in zip(out, is_quaxed[cond_nconsts + body_nconsts:])
    ]

def body_fn(a: jax.Array):
    return a + 1

def cond_fn(a: jax.Array):
    return a[1] < 10

def loop_fn(a: jax.Array):
    res = jax.lax.while_loop(
        body_fun=body_fn,
        cond_fun=cond_fn,
        init_val=a,
    )
    return res

a = ArrayWrapper(jnp.arange(10))
a = jax.jit(quax.quaxify(loop_fn))(a)
print(a.array)

Even though this code executes, it is far from optimal. Since body_fn and cond_fn are no longer quaxed, all the advantages of the unit system are lost. In particular, the expected behavior is that the code above raises an exception, because the primitives for addition, less-than, etc. are not registered; instead, the code runs, since cond_fn and body_fn are no longer quaxed.

I don't know how one could integrate the quaxed functions into the XLA WhileOp primitive. Do you have any insights on how to achieve this?

Many thanks
Yannik

`quaxify` on a `jax.grad`

In GalacticDynamics/unxt#4 I'm trying to get jax.grad to work on functions that accept Quantity arguments, and have run into some difficulties.

The following doesn't work,

import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify

jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")

def func(q) -> Quantity:
    return 5 * q**2 + Quantity(1.0, unit="m2")

quaxify(jax.grad)(func)(q[0])

This returns the error TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m2")). The error was expected, since grad checks for scalar outputs (with jax._src.api._check_scalar). The underlying issue appeared to be that _check_scalar calls concrete_aval, which errors on Quantity. quax-compatible classes have an aval() method, so I hooked that up to a handler and registered it into pytype_aval_mappings:

jax._src.core.pytype_aval_mappings[Quantity] = lambda q: q.aval()

quaxify(jax.grad)(func)(q[0])

While this gets a few lines further into grad, it unfortunately causes a disagreement between pytree structures, with the error
TypeError: Tree structure of cotangent input PyTreeDef(*), does not match structure of primal output PyTreeDef(CustomNode(Quantity[('value',), ('unit',), (Unit("m2"),)], [*])).
I haven't figured out how to fix this issue. Any suggestions would be appreciated!

p.s. @dfm has figured out how to do grad on Quantity in jpu by shunting the units to aux data and re-assembling them afterwards. This solution works well, but it is unique to Quantity and requires a custom grad function. I was hoping to get this working with quaxify in a way that didn't require dispatching (using plum, in https://github.com/GalacticDynamics/array-api-jax-compat) to library-specific grad implementations, especially since it's not obvious what to dispatch on to map func to the Quantity implementation.

Compatibility with `jaxtyping`

Hi! Really enjoying quax. I've been working to get galax potentials quaxified (most relevantly in GalacticDynamics/galax#187) and ran into a compatibility issue with jaxtyping's runtime type checking.

If a module has runtime type checking turned on, e.g. install_import_hook("galax.potential", "beartype.beartype"), then quaxed functions don't pass objects through correctly. As an example, from GalacticDynamics/galax#187:

>>> from jax_quantity import Quantity
>>> import galax.potential as gp
>>> import galax.units as gu
>>> pot = gp.KeplerPotential(m=Quantity(1e12, "Msun"), units=gu.galactic)
>>> pot._potential_energy(Quantity([1.0, 0, 0], "kpc"), Quantity(0, "Myr"), pot._G)
TypeCheckError: Type-check error whilst checking the parameters of _potential_energy.
The problem arose whilst typechecking argument 'q'.
Called with arguments: {'self': KeplerPotential(...), 'q': f64[3], 't': i64[], '_G': f64[]}
Parameter annotations: (self, q: Shaped[Quantity, '*batch 3'], t: Union[Shaped[Quantity, '*#batch '], Shaped[Quantity, '*#batch ']], /, _G: Float[Quantity, '']).

Support for `zeros_like` and related

Hi! I'm running into a few issues with quax + jnp.zeros_like.
The built-in quax.zeros.Zeros appears to leak into jnp.zeros_like and related functions.

>>> import jax.numpy as jnp
>>> from quax import quaxify
>>> x = jnp.array([1, 2, 3], dtype=jnp.float64)
>>> quaxify(jnp.zeros_like)(x)
Zero(_shape=(3,), _dtype=dtype('float64'))
>>> quaxify(jnp.empty_like)(x)
Zero(_shape=(3,), _dtype=dtype('float64'))

I think quaxify(jnp.zeros_like) applied to a plain jax.Array should output a jax.Array.

More importantly (to me), when I make a custom quax.ArrayValue subclass it doesn't override the behavior of Zero with these functions.

>>> class MyArray(quax.ArrayValue): ...
>>> ... [override all related primitives]

>>> y = MyArray(x)
>>> quaxify(jnp.zeros_like)(y)
Zero(_shape=(3,), _dtype=dtype('float64'))

jit and vmap best practices

Just a suggestion to add a small section to the docs about best practices for quaxifying jitted and vmapped functions. Thanks!
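For what it's worth, a minimal sketch of the composition pattern already used in the while_loop issue above (plain arrays here, for which quaxify is essentially a pass-through; whether quaxify should sit inside or outside jit/vmap for array-ish inputs is presumably exactly what such a docs section would cover):

import jax
import jax.numpy as jnp
import quax

def f(x):
    return (x ** 2).sum()

# Compose quaxify with jit/vmap like any other transform, as in
# jax.jit(quax.quaxify(loop_fn)) in the while_loop issue above.
jitted = jax.jit(quax.quaxify(f))
batched = jax.vmap(quax.quaxify(f))

print(jitted(jnp.arange(3.0)))
print(batched(jnp.ones((4, 3))))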

quax boundary recursive wrapping of a pytree

In #10 (comment) we discussed a function wrap_arrayish_value_into_tracer_like to handle cases like

def funct(x, y):
    z = wrap_arrayish_value_into_tracer_like(Value(...), x)
    return x + y + z

In GalacticDynamics/galax#187 I'm encountering the lack of this functionality in the _potential_energy methods, where the Parameters of the potential do not pass through the same quax boundary.

def _potential_energy(self, xyz, t, _G):
    m = self.some_param(t)
    return _G * m / ...

In our case some_param is often a callable object with a .value attribute. In testing, self.some_param.value is not wrapped into a tracer when self is wrapped. Would recursively wrapping into the pytree fix this problem?
