zodiax's Introduction


Well hello there, I'm Louis! 👋

I'm a PhD Candidate in Astronomy at The University of Sydney working with my supervisors Peter Tuthill and Benjamin Pope.

My research focuses on advancing the way we approach modelling optical systems, aiming to integrate the huge advancements in Machine Learning of the last decade into optical sciences through the use of Automatic Differentiation. I aim to develop new software, tools and methods that harness these ideas in order to advance the field as a whole, as well as applying these ideas to the Toliman Space Telescope and JWST.

Please don't hesitate to contact me about project collaboration at [email protected]!


Software:

I'm deeply passionate about open-source software and its ability to precipitate multi-disciplinary scientific advancement. Presently I am leading the development of two packages:
  • ∂Lux: An open-source, fully differentiable, GPU accelerated optical modelling framework.
  • Zodiax: An open-source framework for object-oriented Jax as an extension of Equinox for scientific application.

zodiax's People

Contributors

benjaminpope, louisdesdoigts

zodiax's Issues

Test and Implement optimised `filter_jit`

Allow input of a parameter list, and partition/combine around this to make all but the parameters of interest 'static'.

This needs testing to see if anything is actually gained, and if this solves the string parameter issue.

from jax import Array
import zodiax as zdx

class Foo(zdx.Base):
    leaf : Array
    unit : str

    def __call__(self, new_unit: str):
        new_value = convert(self.leaf, self.unit, new_unit) # some simple unit conversion function
        return self.set(('leaf', 'unit'), (new_value, new_unit))

param = 'leaf'
@zdx.filter_jit(param)
@zdx.filter_grad(param)
def f(pytree):
    return pytree('nm').leaf**2

Nan covariance matrix when parameters have array followed by float

This is an odd bug. I haven't had time to dig into the cause, but here is a minimal example:

import jax.numpy as np
import zodiax as zdx
import matplotlib.pyplot as plt
from jax import Array

print(zdx.__version__)

class Foo(zdx.Base):
    a : Array
    b : float

    def __init__(self, a, b):
        self.a = a
        self.b = b
    
    def model(self):
        return self.a ** self.b

foo = Foo(np.array([1, 2, 3], dtype=float), 2.)

> 0.4.0
cov = zdx.self_covariance_matrix(foo, ['a', 'b'], zdx.poiss_loglike)
print(cov)
cov = zdx.self_covariance_matrix(foo, ['b', 'a'], zdx.poiss_loglike)
print(cov)

> [[ nan  nan  nan  nan]
>  [ nan  nan  nan  nan]
>  [ nan  nan  nan  inf]
>  [ nan  nan  nan -inf]]
> [[-0.70989895  0.08762264  0.56825095  1.2324241 ]
>  [ 0.08762264  0.2811295  -0.07013908 -0.1521178 ]
>  [ 0.5682511  -0.07013909 -0.19037326 -0.9865156 ]
>  [ 1.2324241  -0.15211779 -0.98651546 -1.8828206 ]]

As can be seen, nans are returned when the input parameter order is ['a', 'b'], but not when it is ['b', 'a'].

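Oddly, the Fisher matrix calculation itself does not return nans, which leads me to think the nans appear during the matrix inversion. Note, however, that the ['a', 'b'] Fisher matrix printed below is not symmetric and has an all-zero final row, so it is singular.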

fish = zdx.self_fisher_matrix(foo, ['a', 'b'], zdx.poiss_loglike)
print(fish)
fish = zdx.self_fisher_matrix(foo, ['b', 'a'], zdx.poiss_loglike)
print(fish)

> [[-3.4253058   0.          0.         -0.42278463]
>  [ 0.         -3.780815    0.         -3.0264199 ]
>  [ 0.          0.         -3.895044   -6.762014  ]
>  [ 0.          0.          0.          0.        ]]
> [[-12.805318    -0.42278463  -3.0264199   -6.7620144 ]
>  [ -0.42278463  -3.4253058    0.           0.        ]
>  [ -3.0264199    0.          -3.780815     0.        ]
>  [ -6.762014     0.           0.          -3.895044  ]]
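The two Fisher matrices printed above can be compared numerically. The following is a minimal sketch using plain NumPy, with the printed values copied in by hand (so they carry print rounding); it only checks symmetry and rank and does not call zodiax at all.

```python
import numpy as np

# Fisher matrix printed for the parameter order ['a', 'b'] (values copied by hand)
fisher_ab = np.array([
    [-3.4253058,  0.0,        0.0,       -0.42278463],
    [ 0.0,       -3.780815,   0.0,       -3.0264199 ],
    [ 0.0,        0.0,       -3.895044,  -6.762014  ],
    [ 0.0,        0.0,        0.0,        0.0       ],
])

# Fisher matrix printed for the parameter order ['b', 'a'] (values copied by hand)
fisher_ba = np.array([
    [-12.805318,  -0.42278463, -3.0264199, -6.7620144],
    [-0.42278463, -3.4253058,   0.0,        0.0      ],
    [-3.0264199,   0.0,        -3.780815,   0.0      ],
    [-6.762014,    0.0,         0.0,       -3.895044 ],
])

# A well-formed Fisher matrix should be symmetric and, to be invertible, full rank
symmetric_ab = bool(np.allclose(fisher_ab, fisher_ab.T))
symmetric_ba = bool(np.allclose(fisher_ba, fisher_ba.T))
rank_ab = int(np.linalg.matrix_rank(fisher_ab))
rank_ba = int(np.linalg.matrix_rank(fisher_ba))

print(symmetric_ab, rank_ab)
print(symmetric_ba, rank_ba)
```

A rank below 4 means the matrix has no inverse, which would be consistent with nan and inf entries appearing in the covariance for the ['a', 'b'] order. The zero row corresponds to the float parameter b, though the sketch says nothing about why that row is empty.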

Enable gradients through filter function wrt floats

Currently, inside the wrappers, no filter arguments are passed to the eqx.filter_ calls, so they default to eqx.is_array (I believe), which excludes floats.

We should be able to fix this with is_array_or_float = lambda leaf: True if isinstance(leaf, float) else eqx.is_array(leaf) and passing that to the equinox call.
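As a rough illustration of the predicate described above, here is a minimal sketch that depends only on JAX. It assumes `isinstance(leaf, jax.Array)` is an acceptable stand-in for `eqx.is_array`, and the example pytree (a plain dict with a string leaf) is hypothetical. How the predicate is then passed to the Equinox call inside zodiax is not shown.

```python
import jax
import jax.numpy as jnp


def is_array_or_float(leaf):
    """Return True for JAX arrays and Python floats, False for anything else.

    Assumption: the isinstance check on jax.Array stands in for eqx.is_array.
    """
    return isinstance(leaf, (float, jax.Array))


# Hypothetical pytree with an array leaf, a float leaf and a string leaf
leaves = {"a": jnp.array([1.0, 2.0, 3.0]), "b": 2.0, "unit": "nm"}

# Build a boolean mask marking which leaves the predicate would select
mask = jax.tree_util.tree_map(is_array_or_float, leaves)
print(mask)
```

With a mask like this, the float leaf `b` is selected alongside the array leaf `a`, while the string leaf is left out.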

`ValueError` when not passing lists into `get_optimiser`

Example

import jax.numpy as np
import optax
import zodiax
class Foo(zodiax.ExtendedBase):
    bar : np.ndarray

    def __init__(self, bar):
        self.bar = np.array(bar)

foo = Foo([1., 2, 3])
optim, opt_state = foo.get_optimiser('bar', optax.adam(1.))

Stack Trace:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[119], line 10
      8 foo = Foo([1., 2, 3])
      9 # optim, opt_state = foo.get_optimiser(['bar'], [optax.adam(2e-8)])
---> 10 optim, opt_state = foo.get_optimiser('bar', optax.adam(2e-8))

File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/zodiax-0.1.1-py3.11.egg/zodiax/base.py:691, in ExtendedBase.get_optimiser(self, paths, optimisers, get_args, pmap)
    688 optim = multi_transform(opt_dict, param_spec)
    690 # Get filtered optimiser
--> 691 opt_state = optim.init(eqx_filter(self, is_array))
    693 return (optim, opt_state) if not get_args \
    694     else (optim, opt_state, self.get_args(paths, pmap))

File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/combine.py:135, in multi_transform.<locals>.init_fn(params)
    130 if not label_set.issubset(transforms.keys()):
    131   raise ValueError('Some parameters have no corresponding transformation.\n'
    132                    f'Parameter labels: {list(sorted(label_set))} \n'
    133                    f'Transforms keys: {list(sorted(transforms.keys()))} \n')
--> 135 inner_states = {
    136     group: wrappers.masked(tx, make_mask(labels, group)).init(params)
    137     for group, tx in transforms.items()
    138 }
    139 return MultiTransformState(inner_states)

File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/combine.py:136, in <dictcomp>(.0)
    130 if not label_set.issubset(transforms.keys()):
    131   raise ValueError('Some parameters have no corresponding transformation.\n'
    132                    f'Parameter labels: {list(sorted(label_set))} \n'
    133                    f'Transforms keys: {list(sorted(transforms.keys()))} \n')
    135 inner_states = {
--> 136     group: wrappers.masked(tx, make_mask(labels, group)).init(params)
    137     for group, tx in transforms.items()
    138 }
    139 return MultiTransformState(inner_states)

File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/wrappers.py:482, in masked.<locals>.init_fn(params)
    480 def init_fn(params):
    481   mask_tree = mask(params) if callable(mask) else mask
--> 482   masked_params = mask_pytree(params, mask_tree)
    483   return MaskedState(inner_state=inner.init(masked_params))

File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/optax/_src/wrappers.py:478, in masked.<locals>.mask_pytree(pytree, mask_tree)
    477 def mask_pytree(pytree, mask_tree):
--> 478   return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree)

File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/jax/_src/tree_util.py:206, in tree_map(f, tree, is_leaf, *rest)
    173 """Maps a multi-input function over pytree args to produce a new pytree.
    174 
    175 Args:
   (...)
    203   [[5, 7, 9], [6, 1, 2]]
    204 """
    205 leaves, treedef = tree_flatten(tree, is_leaf)
--> 206 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    207 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/mambaforge/envs/dlux/lib/python3.11/site-packages/jax/_src/tree_util.py:206, in <listcomp>(.0)
    173 """Maps a multi-input function over pytree args to produce a new pytree.
    174 
    175 Args:
   (...)
    203   [[5, 7, 9], [6, 1, 2]]
    204 """
    205 leaves, treedef = tree_flatten(tree, is_leaf)
--> 206 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    207 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Expected list, got Array([1., 2., 3.], dtype=float64).

Enable the setting of static fields?

You can actually set static pytree leaves. This would allow string values to be set on the class while still having it recognised as a pure-jax pytree.

The issue is that the method mutates the object in place, which is somewhat dangerous because it operates differently from the rest of the software. We can make the operations behave somewhat as expected, but the behaviour is unusual.

import zodiax as zdx

class Foo(zdx.Base):
    value : str = zdx.field(static=True)
    
    def __init__(self):
        self.value = 'foo'

class FooReturned(Foo):
    def set_static(self, key, value):
        object.__setattr__(self, key, value)
        return self

foo = FooReturned()

# This prints the updated object, like we expect an immutable class to work
print(foo.set_static('value', 'bar'))

# This prints the updated object too, even though we never re-assigned the variable 'foo'
print(foo)

> FooReturned(value='bar')
> FooReturned(value='bar')

class FooNotReturned(Foo):
    def set_static(self, key, value):
        object.__setattr__(self, key, value)

foo = FooNotReturned()
print(foo.set_static('value', 'bar'))

foo = foo.set_static('value', 'bar')
print(foo)

> None
> None

This can be mitigated with a trade-off. The immutable behaviour can be achieved by wrapping the update function in jax.jit with both the key and the value marked as static. The cost is a potential function recompile every time the leaf is set to a new value, and the leaf can only be updated with hashable values. This may not be a problem if higher-level jits are able to cancel out the recompile (needs testing).

from functools import partial
import jax

class FooException(Foo):
    test : float = 1.

    @partial(jax.jit, static_argnums=(1, 2,))
    def set_static(self, key, value):
        object.__setattr__(self, key, value)
        return self

    def set(self, key, value):
        try:
            return super().set(key, value)
        except AssertionError:
            try:
                return self.set_static(key, value)
            except ValueError:
                raise ValueError(f"Static leaves can only be set to hashable data types. Tried to set the leaf '{key}' with a {type(value)} type.")
        return self

foo = FooException()
print(foo.set('test', 2.))

foo = FooException()
print(foo.set('test', 2.).test)

print(foo.set('value', 'bar'))
print(foo)

foo = foo.set('value', 'bar')
print(foo)

> FooException(value='foo', test=f32[])
> 2.0
> FooException(value='bar', test=1.0)
> FooException(value='foo', test=1.0)
> FooException(value='bar', test=1.0)

The implemented method also raises a helpful error if you try to set the leaf to a non-hashable value.

foo = FooException()
print(foo.set('value', jax.numpy.ones(2)))
...
ValueError: Static leaves can only be set to hashable data types. Tried to set the leaf 'value' with a <class 'jaxlib.xla_extension.ArrayImpl'> type.

Notes:

  • This also needs its jit compatibility and differentiability checked.
  • The error that needs to be caught is an empty AssertionError, which is generic and slightly dangerous to catch. It may be possible to talk with Patrick about adding a specific error message that can be checked against directly, ensuring that we are not catching any other errors.
  • Don't forget to add warnings about JAX 'side effects': only one branch of a Python if-else will be compiled when the condition depends on these static fields.
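
The recompile cost and the single-branch behaviour noted above can be seen with a small sketch that uses only `jax.jit` and `static_argnums`. The function `to_unit` and the list `trace_log` are hypothetical names for illustration and are not part of zodiax.

```python
from functools import partial

import jax

trace_log = []  # records every time Python actually traces the function


@partial(jax.jit, static_argnums=1)
def to_unit(value, unit):
    # Python side effect: this runs only while tracing, not on cached calls
    trace_log.append(unit)
    # 'unit' is static, so this branch is resolved in Python at trace time
    # and only the chosen branch ends up in the compiled function
    if unit == "nm":
        return value * 1e9
    return value


a = to_unit(1.0, "nm")  # first call: traces and compiles for unit='nm'
b = to_unit(2.0, "nm")  # same static value: reuses the compiled function
c = to_unit(1.0, "m")   # new static value: traces and compiles again

print(trace_log)
```

Each new static value triggers a fresh trace, which is the recompile cost mentioned in the trade-off. Any Python-level logic on the static value is baked into that compiled version.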

Path API

I came across this library when I referenced your tree_at issue with Equinox as one of the motives for PyTreeClass's existence. PyTreeClass now supports Equinox (or any pytree) and adds functional, composable, lens-like setters and getters, so why not give it another shot and let me know if it solves your problem?

Looking at this library implementation, this looks very similar to the first version of PyTreeClass, which had some limitations; however, as jax now supports path API, I think it's better to migrate to their API.

For example:

import equinox as eqx
import pytreeclass as pytc
import jax


class Tree(eqx.Module):
    weight: jax.Array = jax.numpy.array([-1, 2, 3])
    bias: jax.Array = jax.numpy.array([1])
    counter: int = 1

    @property
    def at(self):
        return pytc.AtIndexer(self, ())


tree = Tree()

tree = (
    tree.at["counter"]
    .set(1)  # set counter to 1
    .at[jax.tree_util.tree_map(lambda x: x < 0, tree)]
    .set(0)  # set negative values to 0
    .at["bias"].set(100)  # set bias to 100
)

print(tree.weight)
# [0 2 3]
print(tree.bias)
# 100
print(tree.counter)
# 1
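
For reference, the JAX path API mentioned above can be used directly, without PyTreeClass. This is a minimal sketch on a plain dict pytree (a hypothetical stand-in for the `Tree` class), using `jax.tree_util.tree_flatten_with_path`, `jax.tree_util.keystr` and `jax.tree_util.tree_map_with_path`.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-dict pytree standing in for the Tree class
tree = {
    "weight": jnp.array([-1, 2, 3]),
    "bias": jnp.array([1]),
    "counter": 1,
}

# Flatten the pytree into (key path, leaf) pairs and render each path as a string
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(tree)
paths = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]
print(paths)


def clip_weight(path, leaf):
    """Set negative entries to zero, but only for the leaf at ['weight']."""
    if jax.tree_util.keystr(path) == "['weight']":
        return jnp.maximum(leaf, 0)
    return leaf


# Map over the pytree with access to each leaf's key path
clipped = jax.tree_util.tree_map_with_path(clip_weight, tree)
print(clipped["weight"])
```

Dictionary keys are flattened in sorted order, so the printed paths are sorted alphabetically rather than following insertion order.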

Setting `None` requires wrapping in a list

Setting a parameter to None without wrapping the path and value in lists results in a ValueError.

Minimal Example:

import zodiax as zdx
class Foo(zdx.Base):
    param : float = 1.

foo = Foo()

# Works
bar = foo.set(['param'], [None])

# ValueError
bar = foo.set('param', None)

Stack Trace:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [10], in <cell line: 7>()
      5 foo = Foo()
      6 bar = foo.set(['param'], [None])
----> 7 bar = foo.set('param', None)

File ~/mambaforge/envs/dlux/lib/python3.10/site-packages/zodiax/base.py:296, in Base.set(self, paths, values, pmap)
    275 def set(self   : Pytree,
    276         paths  : Union[str, list],
    277         values : Union[Any, list],
    278         pmap   : dict = None) -> Pytree:
    279     """
    280     Set the leaves specified by paths with values.
    281 
   (...)
    294         The pytree with leaves specified by paths updated with values.
    295     """
--> 296     new_paths, new_values = self._format(paths, values, pmap)
    298     # Define 'where' function and update pytree
    299     get_leaves_fn = lambda pytree: pytree._get_leaves(new_paths)

ValueError: not enough values to unpack (expected 2, got 1)

This should be fixable by checking for None as the values input and automatically wrapping it in a list.
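
The wrapping step could look something like the following pure-Python sketch. The helper `format_inputs` is hypothetical and is not zodiax's actual `_format` method; it only illustrates normalising a single path and a single value (including None) into lists.

```python
def format_inputs(paths, values):
    """Hypothetical helper: normalise paths and values into equal-length lists.

    A single string path is paired with whatever value was given, including
    None, so that foo.set('param', None) and foo.set(['param'], [None])
    would be treated the same way.
    """
    if isinstance(paths, str):
        # Single path: wrap both inputs, even when the value is None
        return [paths], [values]

    # Multiple paths: values must already be a list of matching length
    if not isinstance(values, list) or len(values) != len(paths):
        raise ValueError(
            f"Expected a list of {len(paths)} values, got {values!r}."
        )
    return list(paths), values


print(format_inputs("param", None))
print(format_inputs(["param"], [None]))
```

Branching on the type of `paths` rather than testing `values is None` keeps None usable as a legitimate value to set.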

Gradient of float returns None

When doing gradient descent, the gradient of a parameter of type float always returns None. This was fixed by changing the parameter to a jax array.

doc suggestions

Hi Louis,

Great work.

Suggestion about the docs: the huge long notebook could be split up. I think a set of very brief notebooks, each showing one of the following APIs, would be more accessible for a general reader:

  • get, set, add, pmap
  • optax
  • numpyro

The text in the first paragraph at the top could be edited for tone and length and go onto the landing page rather than one of the notebooks.

Cheers,

Ben
