dlux's Introduction

∂Lux

Differentiable Optical Models as Parameterised Neural Networks in Jax using Zodiax

Contributors: Louis Desdoigts, Jordan Dennis, Adam Taras, Max Charles, Connor Langford, Benjamin Pope, Peter Tuthill

∂Lux is an open-source differentiable optical modelling framework that harnesses the structural isomorphism between optical systems and neural networks, giving forward models of optical systems as parametric neural networks.

∂Lux is built in Zodiax, an open-source object-oriented Jax framework built as an extension of Equinox for scientific programming. This framework allows for the creation of complex optical systems involving many planes, with phase and amplitude screens in each, and propagates between them in the Fraunhofer or Fresnel regimes. This enables fast phase retrieval, image deconvolution, and hardware design in high dimensions. Because ∂Lux models are fully differentiable, you can optimise them by gradient descent over millions of parameters, or use Hamiltonian Monte Carlo to accelerate MCMC sampling. Our code is fully open-source under a 3-clause BSD license, and we encourage you to use it and build on it to solve problems in astronomy and beyond.

The ∂Lux framework is built in Zodiax, which gives it a deep range of capabilities from both Jax and Equinox.

For an overview of these capabilities and different optimisation methods in Zodiax, please go through this Zodiax Tutorial.

Documentation: https://louisdesdoigts.github.io/dLux/

Requires: Python 3.10+, Jax 0.4.13+, Zodiax 0.4+

Installation: pip install dLux

If you want to run the tutorials locally, you can install the 'extra' dependencies like so: pip install 'dLux[extras]'

Collaboration & Development

We are always looking to collaborate and further develop this software! We have focused on flexibility and ease of development, so if you have a project you want to use ∂Lux for but it currently lacks the required capabilities, or if you have general questions, thoughts or ideas, don't hesitate to email me or contact me on twitter! More details about contributing can be found in our contributing guide.

Publications

We have a multitude of publications in the pipeline using ∂Lux, some built from our tutorials. To start, we would recommend looking at this invited talk on ∂Lux, which gives a good overview and has an attached recording of it being presented! We also have this poster!

dlux's Issues

Logical Cohesion of `Layer` Objects

I was indexing the layers.py file and noticed that we have both ApplyBasicOPD and ApplyOPD. I was wondering if we could amalgamate (that's a lot of a's) both of these into a WavePlate/PhasePlate class. The call methods are similar: one adds a scalar, the other an array. This is by no means an important or necessary change, but I think it would help the code match tangible objects, although it may not be possible.

Let me know what you think.

Propagator Structure

Hello all,
While working on the GaussianPropagator class I noticed that there was some code duplication in the FresnelProp class, which implements a minimal mft. Additionally, the long-term plan for the GaussianPropagator was that it would use MFT where possible. This had me thinking that it might be worth moving our mft and fft algorithms into an (abstract) Propagator class. From here we could access them in FresnelProp as self._mft and self._fft, and similarly in GaussianPropagator.

The reason I say that this class is abstract is that I would also like to have a __call__(self, distance, wavefront) method defined without implementation (raising an error). This is valuable because the dunder method __call__ is not particularly beginner friendly, but it is how we have implemented the actual propagation in MFT, FFT and FresnelProp. This means that it is enforced by the OpticalSystem during iteration and is a required behaviour of any user-implemented propagators.

To build on my suggestion above, we could have FraunhoferPropagator(Propagator), FresnelPropagator(Propagator) and GaussianPropagator(Propagator) as the "front end" propagators. @LouisDesdoigts and I need to discuss the ownership of the optical transfer functions, since I had put mine in the GaussianWavefront class but I notice that his propagators seem to implicitly contain the transfer functions. A dummy implementation of just the function signatures is available in the Propagator branch, which I just shared. I will not proceed further until I have more feedback.
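The base-class idea could be sketched roughly as follows. This is a minimal illustration in plain Python, not the code on the Propagator branch: the bodies of _fft and _mft are stand-in placeholders, and UnfinishedPropagator is a hypothetical user class that shows what happens when __call__ is not overridden.

```python
class Propagator:
    """Hypothetical base class holding the shared transform helpers."""

    def _fft(self, wavefront):
        # Placeholder: the shared FFT implementation would live here.
        return ("fft", wavefront)

    def _mft(self, wavefront):
        # Placeholder: the shared MFT implementation would live here.
        return ("mft", wavefront)

    def __call__(self, distance, wavefront):
        # Deliberately unimplemented so every concrete propagator must override it.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__(distance, wavefront)"
        )


class FresnelPropagator(Propagator):
    def __call__(self, distance, wavefront):
        # A real implementation would also apply the transfer function.
        return self._mft(wavefront)


class UnfinishedPropagator(Propagator):
    """A user-defined propagator that forgot to override __call__."""


result = FresnelPropagator()(1.0, "wavefront")

try:
    UnfinishedPropagator()(1.0, "wavefront")
    error_message = None
except NotImplementedError as error:
    error_message = str(error)
```

With this layout the error surfaces the first time an incomplete propagator is called, and the message names the offending class.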

As always let me know what you think,

Jordan


`Image` Class to Reduce repetition

Hi all,
A large number of classes store pixel arrays, i.e. images. There is certain functionality associated with this, for example the generation of paraxial arrays. I have tried to keep the naming consistent, but we have Propagator._get_pixel_grid and Wavefront.get_pixel_positions, which do the same thing. There is the option to make Wavefront.get_pixel_positions an @staticmethod with npix as a parameter.

Aside: A static method does not require the class to instantiate an object and can literally be called like Wavefront.get_pixel_positions(npix).

While this is probably the neatest solution at present, it does not make much physical sense. Why should the wavefront describe that behaviour? However, I do think we should include the logic of the function in only one place. I think this because Propagator and Wavefront actually defined the paraxial centre differently, and when I discussed this with @LouisDesdoigts we decided that it was a bug. Additionally, some Layer classes use this functionality, for example the CircularAperture.

The typical OOP solution is to build a class Image or ImageHolder that the other classes can inherit from (thankfully Python supports multiple inheritance). This would hold the parameter _image : Array and define a series of useful methods.

class Image:
    _image: Array

    # NOTE: names are not final, just suggestions
    def get_pixel_coordinates(self) -> Vector:  # Along edge
        ...

    def get_pixel_positions(self) -> Tensor:  # 2 stacked grids in 3rd order tensor
        ...

    def get_radial_positions(self) -> Matrix:  # stacked grid
        ...

The downfall of this is that you get a very complicated class hierarchy (still quite simple in our case), which is the general downfall of OOP.

A further option is to register all the methods of the class I describe above as @staticmethods, making the nearest thing to a static class Python provides. With reference to #58 this could go into src/utilities/image.py and then usage would be through Image.get_pixel_coordinates(npix) etc.

I quite like the idea of implementing such a static class, because it also avoids subclasses ending up with lots of methods that they never use, as can happen with OOP.
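A minimal sketch of such a static class is given here, under some loud assumptions: plain Python lists stand in for arrays to keep the example dependency-free, the method names are the suggestions from above, and the centring convention (coordinates measured from the middle of the array) is an assumption rather than the agreed definition.

```python
import math


class Image:
    """Hypothetical static helper class; no instance is ever created."""

    @staticmethod
    def get_pixel_coordinates(npix):
        # Pixel centres along one edge, measured from the array centre.
        # The centring convention used here is an assumption.
        centre = (npix - 1) / 2
        return [index - centre for index in range(npix)]

    @staticmethod
    def get_pixel_positions(npix):
        # Two stacked grids: x varies along each row, y varies down the rows.
        coordinates = Image.get_pixel_coordinates(npix)
        x_grid = [list(coordinates) for _ in range(npix)]
        y_grid = [[y] * npix for y in coordinates]
        return [x_grid, y_grid]

    @staticmethod
    def get_radial_positions(npix):
        # Distance of each pixel centre from the array centre.
        x_grid, y_grid = Image.get_pixel_positions(npix)
        return [
            [math.hypot(x, y) for x, y in zip(x_row, y_row)]
            for x_row, y_row in zip(x_grid, y_grid)
        ]


edge = Image.get_pixel_coordinates(4)
radii = Image.get_radial_positions(3)
```

Because every method is static, Propagator, Wavefront and the Layer classes could all call Image.get_pixel_coordinates(npix) without inheriting from anything, so the centre is defined in exactly one place.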

As always let me know what you think.

Regards

Jordan

Notebooks need updating on layer internals

I was trying to rerun Instrument Calib and ran into errors in block 6, line 5, which I had to change to det_npix = lays[-1].pixels_out. I suspect this isn't the only case. Going forward, get/set methods could avoid some of this, but at the moment I think reaching into the layer is fine.

Combine `PhysicalWavefront.__init__` and `CreateWavefront`

Hi all,
While writing tests I have noticed that having the constructor split between two classes leaves a lot of None values exposed within the pytree. For example if I want to test the PhysicalWavefront.get_xycoords(), I must type:

import dLux

wavefront = dLux.PhysicalWavefront(wavel, offset)
# Calling wavefront.get_xycoords() here will throw an error because self.pixelscale
# is None.
wavefront = dLux.CreateWavefront(npix, wavefront_size)({"Wavefront": wavefront})
coordinates = wavefront.get_xycoords()

My proposal is that we add npix and wavefront_size as optional keyword parameters to the PhysicalWavefront constructor. We can then add a simple None check, for example:

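class PhysicalWavefront(equinox.Module):
    # Field declarations omitted for brevity; default_value is a placeholder.
    def __init__(self, wavel, offset, /, npix=None, wavefront_size=None,
            amplitude=None, phase=None):
        self.wavel = wavel
        self.offset = offset
        # Assumes pixelscale is wavefront_size / npix when both are given.
        self.pixelscale = (default_value if npix is None or wavefront_size is None
            else wavefront_size / npix)
        self.amplitude = default_value if amplitude is None else amplitude
        self.phase = default_value if phase is None else phase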

That way every attribute is assigned in one place. This is similar to what I have implemented in dLux.GaussianWavefront and we will discuss it in #42. Another motivator is that the combination of CreateWavefront and PhysicalWavefront.__init__ creates four PhysicalWavefront objects, because equinox.tree_at clones the pytree, applies the modification, and we then reassign the result to the same variable in scope.

The main counterargument is that the constructor then needs conditional logic, which can cause jax issues. I imagine that we only need one PhysicalWavefront per wavel so, based on my probably wrong assumptions, we would be instantiating these before the jit- and grad-transformed elements of the program. If this is untrue and we need to be able to trace through the constructor, the syntax changes from:

self.parameter = default_value_of_parameter if parameter is None else parameter

to:

self.parameter = jax.lax.cond(parameter == None, 
    lambda: default_value_of_parameter, lambda: parameter, parameter)

with the downside being that jax logic calls are actually, or rather used to be, much slower than python logic calls. Let me know what you think as I am sure there is a good reason for the separation.

Regards

Jordan

`Wavefront`

Hi all,
After our meeting on Friday @LouisDesdoigts and I agreed that a base class Wavefront would be required to keep track of the mutators and accessors already existing in PhysicalWavefront. Specifically the functionality that we discussed to include in this class was:

  • get_xycoords() and associated helper functions.
  • update_phasor()

I am creating this issue to keep track of this change and request further feedback on the division of labour between the classes. I have created a new branch of GaussianWavefront called Wavefront to keep track of this progress.

Add GoogleColab set up to docs

Windows/Google Colab Quickstart

jaxlib can be problematic on Windows, so we suggest users run our library on Google Colab.
There are a few extra steps to get set up:

  1. At the top of each colab file you will need
!git clone https://github.com/LouisDesdoigts/dLux.git # Download latest version
!cd dLux; pip install . # Navigate to ∂Lux and install from source
!pip install latex # Not needed, but just to make plots nicer
!apt install texlive texlive-latex-extra texlive-fonts-recommended dvipng cm-super
  2. If you are using a notebook from our tutorials, you should add the notebooks folder to the path for imports to work
import sys
sys.path.insert(0,  './dLux/notebooks')


Tips and Tricks

  • You can read/write data from your own drive using
from google.colab import drive
drive.mount('/content/drive')
  • View the files using the left sidebar to navigate

Wavefront type-promotion in the CreateWavefront class

Currently the OpticalSystem instantiates a PhysicalWavefront without an option to change this behaviour. Since we now have new wavefront types in the pipeline, we want to be able to handle type promotion in the CreateWavefront class. I believe @Jordan-Dennis did some initial investigation.

Other wavefront types will not be natively supported until this is completed.

Offset for `PhysicalFresnel`

Hi all,
An offset capability has been added for the PhysicalFresnel algorithm. At present it approximates the offset as being in the focal plane, which is obviously incorrect. I believe that some straightforward trigonometry should make it exact, but for now @LouisDesdoigts and I decided it was good enough. This issue merely documents the need to update it at a later point. See #48 for the structure of the Propagators.

Regards

Jordan

Unnecessary Code

Hi there,
Saw the following in the CreateWavefront class in layers.py, where I have commented the question.

    def __call__(self, params_dict):
        """
        
        """
        # Get relevant parameters
        WF = params_dict["Wavefront"]
        ampl = np.ones([self.npix, self.npix])
        phase = np.zeros([self.npix, self.npix])
        pixelscale = self.pixelscale # We assign the local pixel scale here and then 
        # Only ever use it to reassign the class variable below, where I have marked.
    
        # # Update Wavefront Object
        WF = eqx.tree_at(lambda WF: WF.amplitude,  WF, ampl,  
                         is_leaf=lambda x: x is None)
    
        WF = eqx.tree_at(lambda WF: WF.phase,  WF, phase,  
                         is_leaf=lambda x: x is None)
    
        # Should always be the same value. Not sure why this is necessary. 
        WF = eqx.tree_at(lambda WF: WF.pixelscale, WF, pixelscale, 
                         is_leaf=lambda x: x is None)
        params_dict["Wavefront"] = WF
        return params_dict

Development of optical components layer script

So we want to support arbitrary optical system configurations. This would mean constructing a set of very general layers for common optical components, i.e. lenses, phase plates, etc. @Jordan-Dennis has made a start on a structure for this script which looks rather good. In the long run we could build a minimal web-scraper to map specific optical components like this to their exact chromatic response and manufacturer specifications.

Integration with manufacturers is probably an unnecessary long-term goal, but the basic optical components script is desired.

Restructure /src

So the /src folder is getting a bit unwieldy. I propose we create a sub-directory called extras or something to store the non-core functionality scripts.

The idea would be that the scripts recently moved into /src, such as bayes.py, helpers.py and plotting.py, would live in /extras as a collection of scripts that are designed to work with ∂Lux but are not core and therefore not documented or tested.

Over time some of this functionality would be lifted to /src, such as a subset of the helpers.py functions, as they are widely applicable to the usage of the package.

This could also be a good place to put a development.py script holding layers in development (i.e. not documented and tested) so that we can all have access and use consistent versions. This would be very useful for the development of the JWST models, since we can share code while still being flexible.

Fresnel Normalisation

Hi all,
The PhysicalFresnel algorithm is not currently normalised. It seems to be a factor of 10 out in the tests/integration.py test case. @LouisDesdoigts and I discussed this at length and were unable to find the underlying cause, which we postulated was due to the mft-fft clash. The poppy version uses fft and we use mft. 

Test removal of static_field() attributes and control jax tracing through python/jax data types

Ideally, we want to remove any static_field() attributes and control jax/equinox tracing using basic data types, ie: int(x), float(x), onp.array(x) vs np.array(x).astype(float).

We may need to somehow set the default filtering behaviour for equinox.filter functions to equinox.is_array (True for jax arrays, False for numpy arrays) or equinox.is_inexact_array (True for floating point jax arrays).

Wavefronts.py structure renaming

So now that we have a hierarchical wavefront structure, I suggest we rename the classes for simplicity. The current structure is as follows:

Wavefront:

  • PhysicalWavefront
    • AngularWavefront
  • GaussianWavefront

Under this proposed change the classes would look like this:

BaseWavefront:

  • Wavefront
    • AngularWavefront
  • GaussianWavefront

The aim here is to make the abstract base class explicit and to remove the clunky 'Physical' terminology currently employed. This might precipitate a small change to the naming convention of the propagators, but this isn't necessary.
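As a rough sketch, the proposed hierarchy corresponds to the following class definitions. The bodies are intentionally empty; only the names and inheritance relationships are taken from the lists above.

```python
class BaseWavefront:
    """Abstract base holding behaviour shared by all wavefronts."""


class Wavefront(BaseWavefront):
    """Takes the place of the current PhysicalWavefront."""


class AngularWavefront(Wavefront):
    """Keeps its name but now inherits from the renamed Wavefront."""


class GaussianWavefront(BaseWavefront):
    """Keeps its name but now inherits from BaseWavefront."""
```

One consequence is that isinstance(x, Wavefront) would no longer be true for a GaussianWavefront, so any such checks would need to target BaseWavefront instead.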

Add 'pixels=False' flag to MFT based propagators

A simple flag input allowing the 'offset' values passed to be interpreted in pixel rather than physical units. This is definitely behaviour we want to allow, but it is not the highest on the priority list.
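One way the flag could behave is sketched here. The helper resolve_offset and its pixel_scale argument are hypothetical names, and the conversion assumes the pixel scale is expressed in physical units per pixel.

```python
def resolve_offset(offset, pixel_scale, pixels=False):
    """Return the offset in physical units (hypothetical helper).

    offset      -- (x, y) pair, in pixels if pixels=True, else physical units
    pixel_scale -- physical units per pixel (an assumed convention)
    """
    if pixels:
        # Convert each component from pixel units to physical units.
        return tuple(value * pixel_scale for value in offset)
    # Already in physical units, so pass through unchanged.
    return tuple(offset)


physical = resolve_offset((2, -1), 0.5, pixels=True)
unchanged = resolve_offset((2, -1), 0.5)
```

Doing the conversion once at the boundary would let the propagator internals keep working in physical units only.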

Typing

Hello again,
Another result from @LouisDesdoigts and my meeting on Friday was
the use of type annotations. I like them but @LouisDesdoigts
doesn't, so I wanted to get a second opinion from @benjaminpope
before I removed them. Let me know what you think. 

Accessors in `PhysicalWavefront`

Hi Guys,
I was wondering if we wanted to add accessors, i.e. getters, for some of the attributes in PhysicalWavefront. I have been writing tests today and yesterday, and these require knowledge of the state. We can of course, since this is Python, access them directly. From memory, the PEP 8 guidelines do recommend getters. Let me know what you think.
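For illustration, accessors could take either of the two forms below. This is a sketch on a stand-in class, not the real PhysicalWavefront; the attribute names and the method name get_wavel are illustrative assumptions.

```python
class WavefrontSketch:
    """Stand-in for PhysicalWavefront, used only to illustrate accessors."""

    def __init__(self, wavel, offset):
        self._wavel = wavel
        self._offset = offset

    def get_wavel(self):
        # Explicit getter method.
        return self._wavel

    @property
    def offset(self):
        # Read-only property: accessed like an attribute, no call needed.
        return self._offset


wavefront = WavefrontSketch(wavel=550e-9, offset=(0.0, 0.0))
wavel = wavefront.get_wavel()
offset = wavefront.offset
```

Either form lets tests inspect the state without depending on how the attributes are stored internally.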

Regards

Jordan
