
Comments (1)

aseyboldt commented on June 19, 2024

Hi :-)

You are right, we could really use some documentation on that. If you can share a notebook that we can add to the docs, that would be great!

If you use solve_ivp from the as_pytensor module (the old as_aesara module), the adjoint solver is used by default (controlled by the derivatives: str = 'adjoint' argument).
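For reference, here is a minimal sketch of what such a call can look like, loosely following the Lotka-Volterra example from the sunode README. Treat the exact signature and return values here as assumptions and check the README for the authoritative version:

```python
import numpy as np
import pytensor.tensor as pt
from sunode.wrappers.as_pytensor import solve_ivp

def lotka_volterra(t, y, p):
    # Right-hand side of the ODE, using sunode's named-state style.
    return {
        'hares': p.alpha * y.hares - p.beta * y.lynx * y.hares,
        'lynx': p.delta * y.hares * y.lynx - p.gamma * y.lynx,
    }

# Symbolic parameters; in practice these would often be PyMC random variables.
alpha, beta, gamma, delta = (pt.dscalar(name) for name in ('alpha', 'beta', 'gamma', 'delta'))
times = np.linspace(0, 10, 100)

# derivatives='adjoint' is the default, written out here for clarity.
solution, *_ = solve_ivp(
    y0={'hares': (10.0, ()), 'lynx': (5.0, ())},
    params={'alpha': (alpha, ()), 'beta': (beta, ()),
            'gamma': (gamma, ()), 'delta': (delta, ())},
    rhs=lotka_volterra,
    tvals=times,
    t0=times[0],
    derivatives='adjoint',
)
```

Gradients of anything downstream of solution with respect to the parameters and initial values then flow through the adjoint solver automatically.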

About your questions:

The adjoint solver corresponds to the backward step in reverse-mode autodiff, or to the pullback from differential geometry.

We assume that we want to compute the gradient of some large function $h: \mathbb{R}^n \to \mathbb{R}$. In an application that could for instance be a posterior log probability function that maps parameter values to their unnormalized density. We split this large function into smaller parts, and one of those parts is the function $f: \mathbb{R}^n \to \mathbb{R}^m$ that solves the ODE, so $h(x) = g(f(x))$. Here $\mathbb{R}^n$ contains all parameters, initial conditions, and time points where we want to evaluate the solution, and $\mathbb{R}^m$ contains the solution of the ODE at those points. The other part, $g$, is the function that maps the solution of the ODE to a log prob value (i.e. a likelihood).

When we compute the gradient of $h$, we can isolate the contribution of $f$ using the chain rule, and basically ask: if we already know the gradient of the later part of $h$, namely $g$, what is the gradient of $h$? So we define a function that takes those gradients of $g$ as input and returns the gradients of $h$. This is exactly what happens in solve_backward. The gradients of $g$ are called grads in the code. The final gradients of $h$ are split in two parts: grads_out for the gradients with respect to the parameters, and lambda_out for the gradients with respect to the initial conditions (-lambda_out actually, that's how this was defined by sundials for some reason...).
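To make that chain-rule step concrete, here is a small self-contained NumPy sketch. The toy f stands in for the ODE solve, and pullback_f plays the role that solve_backward plays in sunode; all names here are purely illustrative:

```python
import numpy as np

def f(x):
    # Stand-in for the ODE solve: parameters/initial conditions -> solution values.
    return np.array([x[0] ** 2, x[0] * x[1]])

def pullback_f(x, grads):
    # Vector-Jacobian product: `grads` is dg/df evaluated at f(x),
    # and the result is dh/dx. The Jacobian of the toy f is
    # [[2*x0, 0], [x1, x0]], so J^T @ grads works out to:
    return np.array([2 * x[0] * grads[0] + x[1] * grads[1],
                     x[0] * grads[1]])

def g(y):
    # Stand-in for the likelihood: solution values -> scalar "log prob".
    return np.sum(y ** 2)

def grad_g(y):
    return 2 * y

x = np.array([1.5, -0.5])
# Gradient of h = g(f(x)) at x, without ever building the Jacobian explicitly.
grad_h = pullback_f(x, grad_g(f(x)))
```

The point is that pullback_f only ever needs the incoming gradients of $g$, which is what makes the adjoint method cheap when $h$ is scalar-valued.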

This is essentially also how sundials does things internally, except that it generalizes the idea a bit more. The way we think about "the function $f$ that solves the ODE" above isn't as general as it could be. Instead of just asking what the solution will be at certain points, we could also say that $f$ should return the solution function itself: $f$ maps parameters and initial conditions to the solution function of the ODE, and correspondingly $g$ is then a functional that takes a function as input and returns a scalar. This allows some things that sunode doesn't currently support, like computing the gradient of an integral over the solution. So for instance you could have a loss function that compares the solution function to a target solution (sketched below).
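Writing $F(y, p, t)$ for the ODE right-hand side (not the solve map $f$ from above), such an integral loss and its gradient look roughly like this in the continuous adjoint formulation that CVODES implements (the exact boundary terms depend on how the initial conditions enter, so treat this as a sketch):

$$G(p) = \int_{t_0}^{T} g\bigl(y(t), p\bigr)\,dt, \qquad \dot y = F(y, p, t), \quad y(t_0) = y_0(p).$$

The adjoint state $\lambda(t)$ is integrated backwards from $T$ to $t_0$,

$$\dot\lambda = -\left(\frac{\partial F}{\partial y}\right)^{T}\lambda - \left(\frac{\partial g}{\partial y}\right)^{T}, \qquad \lambda(T) = 0,$$

and the gradient of the loss is then

$$\frac{dG}{dp} = \lambda(t_0)^{T}\,\frac{\partial y_0}{\partial p} + \int_{t_0}^{T}\left(\frac{\partial g}{\partial p} + \lambda^{T}\frac{\partial F}{\partial p}\right)dt.$$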

In what context are you using sunode? If you don't use the pytensor wrappers, you'll have to apply the chain rule yourself to get gradients of the composite function.

I hope this explanation helps at least a bit. Feel free to ask for clarification if something is unclear; this isn't the easiest subject to write about. :-)


