probml / dynamax
State Space Models library in JAX
Home Page: https://probml.github.io/dynamax/
License: MIT License
Currently GaussianHMM.m_step computes the MLEs for mu_k and Sigma_k, and the MAP estimate for the transition matrix A (the latter uses a weak regularizing Dirichlet prior). Modify this to allow an optional Normal-Inverse-Wishart prior to be specified for p(mu_k, Sigma_k | z=k), so the M step supports MAP estimation as well as MLE.
Also make it possible to add the log prior to the log marginal likelihood, so we can also compute the MAP estimate using SGD.
Add a unit test to learning_test.py
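For reference, a minimal sketch of what the MAP update for one state could look like under a NIW prior. The helper name, argument names, and the choice of the joint posterior mode for Sigma are illustrative assumptions, not ssm-jax API:

```python
import jax.numpy as jnp

def niw_map_estimate(w, X, mu0, kappa0, nu0, Psi0):
    """MAP estimate of (mu, Sigma) for one HMM state under a NIW prior.

    w  : (T,) posterior responsibilities for this state
    X  : (T, D) emissions
    (mu0, kappa0, nu0, Psi0) : NIW hyperparameters (illustrative names).
    """
    D = X.shape[1]
    N = w.sum()
    xbar = (w[:, None] * X).sum(axis=0) / N
    diff = X - xbar
    S = (w[:, None] * diff).T @ diff                  # weighted scatter matrix
    mu_map = (kappa0 * mu0 + N * xbar) / (kappa0 + N)
    d0 = xbar - mu0
    Psi_n = Psi0 + S + (kappa0 * N / (kappa0 + N)) * jnp.outer(d0, d0)
    Sigma_map = Psi_n / (nu0 + N + D + 2)             # joint mode of the NIW posterior
    return mu_map, Sigma_map
```

Setting kappa0 and nu0 small recovers something close to the MLE, which is how the "optional prior" could degrade gracefully.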
Modify inference.py in HMM to support "ragged" sequence lengths.
Check that EM works on https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/demos/hmm_bach_chorales.ipynb
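One common way to support ragged lengths is to pad all sequences to a common length and carry a boolean validity mask; the inference code would then multiply per-step log-likelihoods by the mask. A minimal sketch (the helper name is illustrative, not existing ssm-jax code):

```python
import jax.numpy as jnp

def pad_sequences(seqs, pad_value=0.0):
    """Pad a list of variable-length (T_i, D) emission sequences to a common
    length T_max and return a (N, T_max) boolean mask of valid time steps."""
    T_max = max(len(s) for s in seqs)
    D = seqs[0].shape[1]
    batch = jnp.stack([jnp.concatenate(
        [s, jnp.full((T_max - len(s), D), pad_value)]) for s in seqs])
    mask = jnp.stack([jnp.arange(T_max) < len(s) for s in seqs])
    return batch, mask
```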
See sec 29.5.1 of https://probml.github.io/pml-book/book2.html
and this review paper:
S. Chiappa, “Explicit-Duration Markov Switching Models,” Foundations and Trends in Machine Learning, vol. 7, no. 6, pp. 803–886, 2014, doi: 10.1561/2200000054. [Online]. Available: http://www.nowpublishers.com/article/Details/MAL-054
Should reproduce a figure similar to https://github.com/probml/JSL/blob/main/jsl/demos/ekf_vs_ukf.py
Rename hmm_fit_minibatch_gradient_descent in https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L75 to hmm_fit_sgd. Rename emissions to batch_emissions. Add a comment that the input is (N, T), but you take a minibatch of size (B, T) at each step. Remove the old hmm_fit_sgd.
Move the permutation step in https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L93 inside _sample_minibatches. Split the RNG key. (Also, if B == N there is no need to do a random permutation.) Add a comment that you are sampling a random subset of entire sequences, not time steps.
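A sketch of what _sample_minibatches could look like after these changes (the exact signature in learning.py may differ):

```python
import jax
import jax.numpy as jnp

def sample_minibatches(key, batch_emissions, batch_size, shuffle=True):
    """Yield minibatches of whole sequences (rows of the (N, T, ...) array),
    not of time steps. Split the RNG key rather than reusing it, and skip
    the permutation when batch_size == N."""
    N = batch_emissions.shape[0]
    if shuffle and batch_size < N:
        key, subkey = jax.random.split(key)          # split, don't reuse, the key
        perm = jax.random.permutation(subkey, N)
    else:
        perm = jnp.arange(N)                         # B == N: no shuffle needed
    for start in range(0, N, batch_size):
        yield batch_emissions[perm[start:start + batch_size]]
```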
The ssm-jax GaussianHMM class only supports full covariance matrices, one per state. But hmmlearn's GaussianHMM supports more options (see below). We should support these too.
(n_components, ) if “spherical”,
(n_components, n_features) if “diag”,
(n_components, n_features, n_features) if “full”,
(n_features, n_features) if “tied”.
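For reference, a hedged sketch of how the four compact storage shapes could be expanded into full per-state covariances. The helper itself is illustrative (not existing ssm-jax code); the argument names mirror hmmlearn's covariance_type convention:

```python
import jax.numpy as jnp

def expand_covariances(covs, covariance_type, n_components, n_features):
    """Expand hmmlearn-style compact covariance storage into a full
    (n_components, n_features, n_features) array."""
    if covariance_type == "spherical":               # (K,) -> scaled identities
        return covs[:, None, None] * jnp.eye(n_features)
    if covariance_type == "diag":                    # (K, D) -> diagonal matrices
        return jnp.stack([jnp.diag(c) for c in covs])
    if covariance_type == "tied":                    # (D, D) shared across states
        return jnp.broadcast_to(covs, (n_components, n_features, n_features))
    return covs                                      # "full": already (K, D, D)
```

Internally we could keep the compact storage (fewer parameters to learn) and only expand where a full matrix is needed.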
Currently all the HMM variants are in their own files and are not imported into models.py, so their names are not discoverable and all the HMM tests fail.
Reimplement https://github.com/probml/pyprobml/blob/master/notebooks/book2/28/hmm_bach_chorales.ipynb to use ssm_jax/hmm/learning.py and add it to ssm_jax/hmm/demos/bach_chorales_sgd
Implement the UKF, the unscented Kalman smoother, unit tests, and a pendulum demo.
Add hmm_forwards_filtering_backwards_sampling_jax from JSL to https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/inference.py. Add a unit test.
Implement some of the hierarchical Dirichlet process (HDP-HMM) algorithms from https://github.com/mattjj/pyhsmm
Currently, BaseHMM.m_step and GaussianHMM.m_step have different inputs and outputs, which breaks the standard hmm_fit_em function. I suggest we standardize on

```python
@classmethod
def m_step(cls, batch_emissions, batch_posteriors, **kwargs):
    ...
    return hmm
```

and

```python
def hmm_fit_em(hmm, batch_emissions, num_iters=50, **kwargs):
    @jit
    def em_step(hmm):
        batch_posteriors, marginal_logliks = hmm.e_step(batch_emissions)
        hmm = hmm.m_step(batch_emissions, batch_posteriors, **kwargs)
        return hmm, marginal_logliks.sum(), batch_posteriors

    log_probs = []
    for _ in range(num_iters):                    # outer EM loop
        hmm, marginal_loglik, batch_posteriors = em_step(hmm)
        log_probs.append(marginal_loglik)
    return hmm, log_probs
```
(Separately, I'm starting to question whether m_step should be a class method. Maybe we should just embrace the objects and set their parameters within the M step, rather than returning a new object. We're already functional at the level of the underlying inference code.)
Implement blocked Gibbs sampling for the LGSSM. Then make a Gibbs-sampling version of this EM demo:
https://github.com/probml/ssm-jax/blob/main/ssm_jax/lgssm/demos/lgssm_learning.py
Some details can be found in this paper:
A. Wills, T. B. Schön, F. Lindsten, and B. Ninness, “Estimation of Linear Systems using a Gibbs Sampler,” IFAC proc. vol., vol. 45, no. 16, pp. 203–208, Jul. 2012, doi: 10.3182/20120711-3-be-2027.00297. [Online]. Available: https://linkinghub.elsevier.com/retrieve/pii/S1474667015379520
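The sampler in Wills et al. alternates two blocks: draw the full state trajectory given the parameters, then draw the parameters given the states. A skeleton of that loop, where sample_states and sample_params are hypothetical caller-supplied closures (e.g. wrapping lgssm_posterior_sample and a conjugate parameter update):

```python
import jax

def blocked_gibbs(key, emissions, init_params, num_iters,
                  sample_states, sample_params):
    """Skeleton of blocked Gibbs for an LGSSM: alternate (1) a joint draw of
    the latent states given parameters and (2) a draw of the parameters
    given the sampled states. The two closures are illustrative hooks, not
    existing ssm-jax functions."""
    params = init_params
    samples = []
    for _ in range(num_iters):
        key, k1, k2 = jax.random.split(key, 3)
        states = sample_states(k1, params, emissions)   # blocked FFBS-style draw
        params = sample_params(k2, states, emissions)   # conjugate update
        samples.append(params)
    return samples
```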
In CategoricalHMM, there is an E step (https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/models.py#L332) which is the same as the one in BaseHMM. Remove the one from CategoricalHMM and check that the tests still pass (e.g., casino).
In poisson_hmm_changepoint, we 'freeze' the start and transition probabilities and just update the emission (observation) model. In hmmlearn you can specify which parameters you want to learn; e.g., for PoissonHMM, you can say params="l" to update just the "lambdas" (i.e., the emissions). We should support the same feature so that we can simplify our poisson_hmm_changepoint demo.
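A minimal sketch of the idea, assuming parameters live in a dict; the dict layout, function name, and letter codes here are illustrative (only "l" mirrors hmmlearn), not the ssm-jax representation:

```python
def m_step_with_frozen_params(hmm_params, updates, params="l"):
    """Apply only the requested parameter updates, hmmlearn-style.
    's' = start probs, 't' = transition matrix, 'l' = emission rates."""
    codes = {"s": "initial_probs", "t": "transition_matrix", "l": "lambdas"}
    return {name: (updates[name] if code in params else hmm_params[name])
            for code, name in codes.items()}
```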
This is to track the 'porting' of all the LDS examples
from https://github.com/lindermanlab/ssm-jax/tree/main/notebooks
Reimplement https://github.com/probml/pyprobml/blob/master/notebooks/book2/29/hmm_bernoulli.ipynb
to use ssm-jax
Currently computes the MLE, which may suffer from zero counts:
https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/models/categorical_hmm.py#L65
Instead just use (pseudocode):

```python
pseudo_counts = 0.1
B[k, :] = normalize(expected_counts[k, :] + pseudo_counts)
```
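A self-contained sketch of this pseudo-count smoothing (the helper name is illustrative):

```python
import jax.numpy as jnp

def smoothed_emission_probs(expected_counts, pseudo_count=0.1):
    """MAP-style emission update: add a small pseudo-count before normalizing
    each row, so states with zero expected counts don't produce zero
    probabilities."""
    counts = expected_counts + pseudo_count
    return counts / counts.sum(axis=-1, keepdims=True)
```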
lgssm_posterior_sample, lgssm_smoother, and the log-likelihoods from lgssm_filter
run_all_demos script to check if all demos run without errors
LGSSMPosterior object
JSL demos we can ignore
Add GaussianARHMM to models.py and then reimplement https://github.com/probml/pyprobml/blob/master/notebooks/book2/28/arhmm_example.ipynb
Reimplement https://github.com/probml/pyprobml/blob/master/notebooks/book2/28/hmm_poisson_changepoint_jax.ipynb
to use ssm-jax instead of distrax and add to ssm_jax/hmm/demos
@karalleyna
As requested by Dr. @murphyk, the following tasks are to be implemented as part of this issue.
*_test.py
pyproject.toml
Currently LGSSM.m_step computes the MLE. We should allow the user to specify an NIW prior on all the matrices and then compute the MAP estimate. (Similar functionality will also be needed by Gibbs sampling for HMM fitting.)
Replace beta.mode in https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/models/bernoulli_hmm.py#L99 with pseudo-counts.
Implement EKF, EK smoother, unit tests, pendulum demo.
The m_step in BaseHMM in https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/models.py#L165 is incorrect. Remove it for now.
Read up on "Generalized EM" algorithm and then test your implementation on GaussianHMM.
Please create separate demos for pendulum_ekf and pendulum_ukf for the Gaussian noise (no outliers) version of https://github.com/probml/JSL/blob/main/jsl/demos/pendulum_1d.py
I tried this

```python
try:
    import ssm_jax
except:
    %pip install git+https://github.com/probml/ssm-jax
    import ssm_jax
```

but get the error No module named 'ssm_jax'
Update https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/demos/demos_test.py so it runs all the scripts in test mode.
Currently https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/demos/casino_training.ipynb relies on hmm_fit_minibatch; it should change to hmm_fit_sgd.
Implement a closed-form M step and add it to https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/models.py#L572
Reimplement https://github.com/probml/pyprobml/blob/master/notebooks/book2/28/bernoulli_hmm_example.ipynb to use ssm_jax/hmm/learning.py and add it to ssm_jax/hmm/demos. (You may need to extend the models.py file to support product-of-Bernoulli likelihoods.)
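If that extension is needed, the per-state log-likelihood could look roughly like this; the function name and shapes are assumptions, not existing ssm-jax code:

```python
import jax.numpy as jnp

def bernoulli_product_log_prob(probs, emissions):
    """Log-likelihood of binary emission vectors under a product-of-Bernoulli
    model: each state k has a vector probs[k] of per-dimension success
    probabilities, and the D dimensions are conditionally independent."""
    probs = jnp.clip(probs, 1e-6, 1 - 1e-6)           # avoid log(0)
    # (K, D) probs vs (T, D) emissions -> (T, K) log-likelihood matrix
    return emissions @ jnp.log(probs).T + (1 - emissions) @ jnp.log1p(-probs).T
```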
This is to track the 'porting' of all the HMM examples
from https://github.com/lindermanlab/ssm-jax/tree/main/notebooks
Should reproduce a figure similar to https://github.com/probml/JSL/blob/main/jsl/demos/ekf_vs_ukf.py
Reimplement https://github.com/probml/JSL/blob/main/jsl/demos/ekf_mlp.py on top of ssm-jax EKF code
Our hmm/models.py is getting too big. We could keep it just for BaseHMM and factor each subclass out into its own file, so we would have ar_hmm.py, gaussian_hmm.py, categorical_hmm.py, poisson_hmm.py, etc. This is more modular, since we can encapsulate model-specific logic (e.g., the M step) in separate files.