dynamax's Introduction

Welcome to DYNAMAX!

Dynamax is a library for probabilistic state space models (SSMs) written in JAX. It has code for inference (state estimation) and learning (parameter estimation) in a variety of SSMs, including:

  • Hidden Markov Models (HMMs)
  • Linear Gaussian State Space Models (aka Linear Dynamical Systems)
  • Nonlinear Gaussian State Space Models
  • Generalized Gaussian State Space Models (with non-Gaussian emission models)

The library consists of a set of core, functionally pure, low-level inference algorithms, as well as a set of model classes which provide a more user-friendly, object-oriented interface. It is compatible with other libraries in the JAX ecosystem, such as optax (used for estimating parameters using stochastic gradient descent), and Blackjax (used for computing the parameter posterior using Hamiltonian Monte Carlo (HMC) or sequential Monte Carlo (SMC)).

Documentation

For tutorials and API documentation, see: https://probml.github.io/dynamax/.

For an extension of dynamax that supports structural time series models, see https://github.com/probml/sts-jax.

For an illustration of how to use dynamax inside of bayeux to perform Bayesian inference for the parameters of an SSM, see https://jax-ml.github.io/bayeux/examples/dynamax_and_bayeux/.

Installation and Testing

To install the latest release of dynamax from PyPI:

pip install dynamax                 # Install dynamax and core dependencies, or
pip install dynamax[notebooks]      # Install with demo notebook dependencies

To install the latest development branch:

pip install git+https://github.com/probml/dynamax.git

Finally, if you're a developer, you can install dynamax along with the test and documentation dependencies with:

git clone git@github.com:probml/dynamax.git
cd dynamax
pip install -e '.[dev]'

To run the tests:

pytest dynamax                         # Run all tests
pytest dynamax/hmm/inference_test.py   # Run a specific test file
pytest -k lgssm                        # Run tests with lgssm in the name

What are state space models?

A state space model or SSM is a partially observed Markov model, in which the hidden state, $z_t$, evolves over time according to a Markov process, possibly conditional on external inputs / controls / covariates, $u_t$, and generates an observation, $y_t$. This is illustrated in the graphical model below.

The corresponding joint distribution has the following form (in dynamax, we restrict attention to discrete time systems):

$$p(y_{1:T}, z_{1:T} | u_{1:T}) = p(z_1 | u_1) p(y_1 | z_1, u_1) \prod_{t=2}^T p(z_t | z_{t-1}, u_t) p(y_t | z_t, u_t)$$

Here $p(z_t | z_{t-1}, u_t)$ is called the transition or dynamics model, and $p(y_t | z_{t}, u_t)$ is called the observation or emission model. In both cases, the inputs $u_t$ are optional; furthermore, the observation model may have auto-regressive dependencies, in which case we write $p(y_t | z_{t}, u_t, y_{1:t-1})$.
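To make the factorization concrete, here is a minimal sketch (not dynamax code) that evaluates $\log p(y_{1:T}, z_{1:T})$ for a discrete-state SSM, i.e. an HMM, with scalar Gaussian emissions and no inputs $u_t$. The function `hmm_joint_log_prob` and its arguments are hypothetical, and it assumes all states share one emission standard deviation.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def hmm_joint_log_prob(states, emissions, initial_probs, transition_matrix, means, scale):
    """log p(y_{1:T}, z_{1:T}) for a discrete-state SSM with scalar Gaussian emissions.

    states: int array (T,); emissions: float array (T,);
    initial_probs: (K,); transition_matrix: (K, K), rows sum to 1;
    means: (K,) emission mean per state; scale: shared emission std dev (assumption).
    """
    # Initial state term: log p(z_1)
    lp = jnp.log(initial_probs[states[0]])
    # Transition terms: sum_{t=2}^T log p(z_t | z_{t-1})
    lp += jnp.sum(jnp.log(transition_matrix[states[:-1], states[1:]]))
    # Emission terms: sum_{t=1}^T log p(y_t | z_t)
    lp += jnp.sum(norm.logpdf(emissions, loc=means[states], scale=scale))
    return lp
```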

We assume that we see the observations $y_{1:T}$ and want to infer the hidden states, either using online filtering (i.e., computing $p(z_t|y_{1:t})$) or offline smoothing (i.e., computing $p(z_t|y_{1:T})$). We may also be interested in predicting future states, $p(z_{t+h}|y_{1:t})$, or future observations, $p(y_{t+h}|y_{1:t})$, where $h$ is the forecast horizon. (Note that by using a hidden state to represent the past observations, the model can have "infinite" memory, unlike a standard auto-regressive model.) All of these computations can be done efficiently using our library, as we discuss below. In addition, we can estimate the parameters of the transition and emission models.
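As a rough illustration of online filtering, the sketch below implements the standard forward algorithm for a discrete-state HMM with `jax.lax.scan`. It returns $p(z_t \mid y_{1:t})$ for every $t$ and the log marginal likelihood $\log p(y_{1:T})$, plus an $h$-step state forecast $p(z_{t+h} \mid y_{1:t})$ obtained by multiplying by the $h$-th power of the transition matrix. This is not dynamax's implementation; the function names are hypothetical, and it assumes the per-state emission log-likelihoods $\log p(y_t \mid z_t = k)$ have already been computed.

```python
import jax
import jax.numpy as jnp

def hmm_forward_filter(initial_probs, transition_matrix, log_likelihoods):
    """Return filtered probs p(z_t | y_{1:t}) of shape (T, K) and log p(y_{1:T}).

    log_likelihoods[t, k] = log p(y_t | z_t = k); rows of transition_matrix sum to 1.
    """
    def step(carry, log_lik_t):
        log_norm, predicted = carry  # predicted = p(z_t | y_{1:t-1})
        # Update: multiply by the likelihood, subtracting the max for numerical stability
        max_ll = jnp.max(log_lik_t)
        unnorm = predicted * jnp.exp(log_lik_t - max_ll)
        norm = unnorm.sum()
        filtered = unnorm / norm
        # Accumulate log p(y_t | y_{1:t-1}) into the log marginal likelihood
        log_norm = log_norm + jnp.log(norm) + max_ll
        # Predict: p(z_{t+1} | y_{1:t}) = sum_k p(z_t = k | y_{1:t}) A[k, :]
        predicted_next = filtered @ transition_matrix
        return (log_norm, predicted_next), filtered

    init = (jnp.array(0.0, dtype=log_likelihoods.dtype), initial_probs)
    (log_marginal, _), filtered_probs = jax.lax.scan(step, init, log_likelihoods)
    return filtered_probs, log_marginal

def predict_states(filtered_last, transition_matrix, horizon):
    """h-step forecast p(z_{t+h} | y_{1:t}) = p(z_t | y_{1:t}) A^h."""
    return filtered_last @ jnp.linalg.matrix_power(transition_matrix, horizon)
```

Smoothing would add a backward pass over these filtered probabilities; it is omitted here to keep the sketch short.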

More information can be found in these books:

Example usage

Dynamax includes classes for many kinds of SSM. You can use these models to simulate data, and you can fit the models using standard learning algorithms like expectation-maximization (EM) and stochastic gradient descent (SGD). Below we illustrate the high level (object-oriented) API for the case of an HMM with Gaussian emissions. (See this notebook for a runnable version of this code.)

import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from dynamax.hidden_markov_model import GaussianHMM

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
num_states = 3
emission_dim = 2
num_timesteps = 1000

# Make a Gaussian HMM and sample data from it
hmm = GaussianHMM(num_states, emission_dim)
true_params, _ = hmm.initialize(key1)
true_states, emissions = hmm.sample(true_params, key2, num_timesteps)

# Make a new Gaussian HMM and fit it with EM
params, props = hmm.initialize(key3, method="kmeans", emissions=emissions)
params, lls = hmm.fit_em(params, props, emissions, num_iters=20)

# Plot the marginal log probs across EM iterations
plt.plot(lls)
plt.xlabel("EM iterations")
plt.ylabel("marginal log prob.")

# Use fitted model for posterior inference
post = hmm.smoother(params, emissions)
print(post.smoothed_probs.shape) # (1000, 3)

JAX allows you to easily vectorize these operations with vmap. For example, you can sample and fit to a batch of emissions as shown below.

from functools import partial
from jax import vmap

num_seq = 200
batch_true_states, batch_emissions = \
    vmap(partial(hmm.sample, true_params, num_timesteps=num_timesteps))(
        jr.split(key2, num_seq))
print(batch_true_states.shape, batch_emissions.shape) # (200,1000) and (200,1000,2)

# Make a new Gaussian HMM and fit it with EM
params, props = hmm.initialize(key3, method="kmeans", emissions=batch_emissions)
params, lls = hmm.fit_em(params, props, batch_emissions, num_iters=20)

These examples demonstrate the dynamax models, but we can also call the low-level inference code directly.

Contributing

Please see this page for details on how to contribute.

About

Core team: Peter Chang, Giles Harper-Donnelly, Aleyna Kara, Xinglong Li, Scott Linderman, Kevin Murphy.

Other contributors: Adrien Corenflos, Elizabeth DuPre, Gerardo Duran-Martin, Colin Schlager, Libby Zhang and other people listed here.

MIT License. 2022

dynamax's People

Contributors

andrewwarrington, calebweinreb, canyon289, edeno, emdupre, ezhang94, gerdm, gileshd, jakevdp, jasondavies, karalleyna, kostastsa, matthew9671, murphyk, partev, patel-zeel, petergchang, raulpl, schlagercollin, slinderman, xinglong-li

dynamax's Issues

Implement EKF

Implement EKF, EK smoother, unit tests, pendulum demo.
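A minimal sketch of one EKF predict/update step, assuming additive Gaussian noise, $z_t = f(z_{t-1}) + q_t$ and $y_t = h(z_t) + r_t$ with $q_t \sim \mathcal{N}(0, Q)$ and $r_t \sim \mathcal{N}(0, R)$, and using `jax.jacfwd` to linearize $f$ and $h$. The names `ekf_step`, `pendulum_f`, and `pendulum_h` are hypothetical. The pendulum model (angle and angular velocity, Euler-discretized, observing the sine of the angle) and its constants `dt` and `g` are illustrative assumptions, not values taken from the JSL demo.

```python
import jax
import jax.numpy as jnp

def ekf_step(mean, cov, y, f, h, Q, R):
    """One extended Kalman filter predict + update step.

    mean, cov: filtered moments of z_{t-1}; y: observation y_t;
    f, h: dynamics and emission functions; Q, R: noise covariances.
    """
    # Predict: linearize the dynamics f around the current filtered mean
    F = jax.jacfwd(f)(mean)
    pred_mean = f(mean)
    pred_cov = F @ cov @ F.T + Q
    # Update: linearize the emission function h around the predicted mean
    H = jax.jacfwd(h)(pred_mean)
    S = H @ pred_cov @ H.T + R                # innovation covariance
    K = jnp.linalg.solve(S, H @ pred_cov).T   # Kalman gain P H^T S^{-1} (P, S symmetric)
    new_mean = pred_mean + K @ (y - h(pred_mean))
    new_cov = pred_cov - K @ S @ K.T
    return new_mean, new_cov

# Hypothetical pendulum model: state = [angle, angular velocity]
dt, g = 0.0125, 9.8  # illustrative constants

def pendulum_f(x):
    return jnp.array([x[0] + x[1] * dt, x[1] - g * jnp.sin(x[0]) * dt])

def pendulum_h(x):
    return jnp.sin(x[:1])  # observe sin(angle), shape (1,)
```

For linear $f$ and $h$ the Jacobians are constant, so this step reduces exactly to the Kalman filter, which makes a convenient unit test.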

blocked Gibbs sampling for LGSSM as an alternative to EM

Implement blocked Gibbs sampling for LGSSM. Then make a Gibbs sampling (GS) version of this EM demo:
https://github.com/probml/ssm-jax/blob/main/ssm_jax/lgssm/demos/lgssm_learning.py

Some details can be found in this paper:
A. Wills, T. B. Schön, F. Lindsten, and B. Ninness, “Estimation of Linear Systems using a Gibbs Sampler,” IFAC proc. vol., vol. 45, no. 16, pp. 203–208, Jul. 2012, doi: 10.3182/20120711-3-be-2027.00297. [Online]. Available: https://linkinghub.elsevier.com/retrieve/pii/S1474667015379520

cleanup LGSSM code

  • Extend inference_test to test lgssm_posterior_sample, lgssm_smoother and the log-likelihoods from lgssm_filter
  • Refactor kf_tracking so it works with new API, and can run in 'silent' mode (no figures)
  • Convert kf_parallel
  • Convert kf_spiral
  • Convert linreg_kf
  • Create a run_all_demos script to check that all demos run without errors
  • Update inference.py so the comments are correct and both the filter and smoother return an LGSSMPosterior object.

JSL demos we can ignore

fix hmm/models.py

Currently all the HMM variants are in their own files and are not imported into models.py, so their names are not discoverable and all HMM tests fail.

cleanup HMM SGD code

Rename hmm_fit_minibatch_gradient_descent in
https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L75
to be hmm_fit_sgd.
Rename emissions to batch_emissions. Add a comment that the input has shape (N,T)
but you take a minibatch of size (B,T) at each step.

Remove old hmm_fit_sgd.

Move the permutation step in
https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L93
inside _sample_minibatches, and split the RNG key.
(Also check whether B=N, in which case no random permutation is needed.)
Add a comment that you are sampling a random subset of entire sequences, not time steps.
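The minibatching behaviour requested above could look roughly like the sketch below. `sample_minibatches` is a hypothetical stand-in for `_sample_minibatches`, not the ssm-jax code. It permutes sequence indices (skipping the permutation when B >= N), slices out whole sequences rather than time steps, and, as an assumption, drops a final partial batch. The caller splits the RNG key once per epoch so each epoch sees a fresh permutation.

```python
import jax.numpy as jnp
import jax.random as jr

def sample_minibatches(key, batch_emissions, batch_size):
    """Yield minibatches of shape (B, T, ...) from batch_emissions of shape (N, T, ...).

    Each minibatch is a random subset of entire sequences, not of time steps.
    A final partial batch (when N is not a multiple of B) is dropped.
    """
    num_seqs = batch_emissions.shape[0]
    if batch_size < num_seqs:
        # Random permutation of sequence indices
        perm = jr.permutation(key, num_seqs)
    else:
        # B >= N: a minibatch would be the whole dataset, so no permutation is needed
        perm = jnp.arange(num_seqs)
    for start in range(0, num_seqs - batch_size + 1, batch_size):
        yield batch_emissions[perm[start:start + batch_size]]

# Usage: split the RNG key once per epoch before drawing minibatches
key = jr.PRNGKey(0)
emissions = jnp.zeros((8, 100, 2))  # N=8 sequences of length T=100, emission_dim=2
for epoch in range(3):
    key, subkey = jr.split(key)  # fresh key per epoch -> fresh permutation
    for minibatch in sample_minibatches(subkey, emissions, batch_size=4):
        print(epoch, minibatch.shape)  # (4, 100, 2)
```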

add GitHub workflow

As requested by Dr. @murphyk, the following tasks are to be implemented as part of this issue.

  • run tests matching the wildcard *_test.py
  • check that the code is formatted with black (per the config in pyproject.toml)

Standard HMM m_step interface

Currently, BaseHMM.m_step and GaussianHMM.m_step have different inputs and outputs, which breaks the standard hmm_fit_em function. I suggest we standardize on:

@classmethod
def m_step(cls, batch_emissions, batch_posteriors, **kwargs):
    ...
    return hmm

and

def hmm_fit_em(hmm, batch_emissions, num_iters=50, **kwargs):
    @jit
    def em_step(hmm):
        batch_posteriors, marginal_logliks = hmm.e_step(batch_emissions)
        hmm = hmm.m_step(batch_emissions, batch_posteriors, **kwargs)
        return hmm, marginal_logliks.sum(), batch_posteriors

(Separately, I'm starting to question whether m_step should be a class method. Maybe we should just embrace the objects and set their parameters within the M step, rather than returning a new object. We're already functional at level of the underlying inference code.)

add support for MAP parameter estimation to GaussianHMM

Currently GaussianHMM.m_step computes the MLEs for mu_k, Sigma_k and the MAP estimate for the transition matrix A (the latter uses a weak regularizing Dirichlet prior). Modify this to allow for an optional Normal-Inverse-Wishart prior to be specified for p(mu_k,Sigma_k|z=k). Modify M step so EM supports MAP as well as MLE.
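One way to implement the MAP update for a single state $k$ is sketched below. It assumes the usual conjugate parameterization $\Sigma_k \sim \mathrm{IW}(\nu_0, \Psi_0)$, $\mu_k \mid \Sigma_k \sim \mathcal{N}(\mu_0, \Sigma_k/\kappa_0)$, and uses the E-step responsibilities $p(z_t = k \mid y_{1:T})$ as weights. It returns the joint mode of the NIW posterior; `niw_map_estimate` and its argument names are hypothetical, not the dynamax or ssm-jax API. Setting $\kappa_0 = 0$, $\Psi_0 = 0$ and $\nu_0 = -D-2$ (an improper limit) recovers the weighted MLE, which gives a simple unit test.

```python
import jax.numpy as jnp

def niw_map_estimate(emissions, weights, mu0, kappa0, nu0, Psi0):
    """MAP estimate of (mu_k, Sigma_k) under a Normal-Inverse-Wishart prior.

    emissions: (T, D); weights: (T,) posterior probs p(z_t = k | y_{1:T}).
    Prior: Sigma ~ IW(nu0, Psi0), mu | Sigma ~ N(mu0, Sigma / kappa0).
    """
    D = emissions.shape[1]
    # Weighted sufficient statistics for state k
    Nk = weights.sum()
    ybar = weights @ emissions / Nk
    centered = emissions - ybar
    Sk = (weights[:, None] * centered).T @ centered
    # Conjugate NIW posterior parameters
    kappa_n = kappa0 + Nk
    nu_n = nu0 + Nk
    mu_n = (kappa0 * mu0 + Nk * ybar) / kappa_n
    diff = ybar - mu0
    Psi_n = Psi0 + Sk + (kappa0 * Nk / kappa_n) * jnp.outer(diff, diff)
    # Joint mode of NIW(mu_n, kappa_n, nu_n, Psi_n): mu = mu_n, Sigma = Psi_n / (nu_n + D + 2)
    return mu_n, Psi_n / (nu_n + D + 2)
```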

Also make it possible to add the log prior to the log marginal likelihood, so we can also compute the MAP estimate using SGD.

Add a unit test to learning_test.py

pendulum example

Please create separate demos for pendulum_ekf and pendulum_ukf for the Gaussian noise (no outliers) version of
https://github.com/probml/JSL/blob/main/jsl/demos/pendulum_1d.py

refactor hmm/models.py

Our hmm/models.py is getting too big. We could keep that file for BaseHMM and factor out each subclass into its own file, giving ar_hmm.py, gaussian_hmm.py, categorical_hmm.py, poisson_hmm.py, etc. This is more modular, since we can encapsulate model-specific logic (e.g., the M step) in separate files.

Implement UKF

Implement UKF, UK smoother, unit tests, pendulum demo.
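The core of a UKF is the unscented transform, which propagates a Gaussian through a nonlinearity using $2n+1$ deterministically chosen sigma points. The sketch below is a minimal version with illustrative defaults (`alpha=1`, `beta=0`, `kappa=3-n`); these are assumptions and may differ from whatever dynamax eventually uses. `unscented_transform` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def unscented_transform(mean, cov, fn, alpha=1.0, beta=0.0, kappa=None):
    """Approximate the mean and covariance of fn(z) for z ~ N(mean, cov)."""
    n = mean.shape[0]
    if kappa is None:
        kappa = 3.0 - n  # illustrative default
    lam = alpha**2 * (n + kappa) - n
    # 2n + 1 sigma points: the mean, plus/minus the columns of chol((n + lam) * cov)
    L = jnp.linalg.cholesky((n + lam) * cov)
    sigmas = jnp.concatenate([mean[None, :], mean + L.T, mean - L.T])
    # Weights for the mean (wm) and covariance (wc)
    wm = jnp.full(2 * n + 1, 1.0 / (2 * (n + lam))).at[0].set(lam / (n + lam))
    wc = wm.at[0].add(1.0 - alpha**2 + beta)
    # Push every sigma point through the nonlinearity
    ys = jax.vmap(fn)(sigmas)
    y_mean = wm @ ys
    centered = ys - y_mean
    y_cov = (wc[:, None] * centered).T @ centered
    return y_mean, y_cov
```

Under the same additive-noise assumptions as the EKF, a UKF predict step would apply `unscented_transform` to the dynamics and then add $Q$ to the covariance. The update step also needs the cross-covariance between state and observation sigma points, which this sketch omits.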

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.