pgmax's Introduction

PGMax

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.

  • General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
  • LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
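To make the LBP bullet concrete, here is a minimal numpy sketch of log-domain sum-product LBP on a pairwise model. It is not PGMax's implementation or API: PGMax handles general factors and compiles to JAX. `loopy_bp`, its arguments, and the damping convention (weight kept on the previous message) are assumptions for illustration. Messages start at zero in log space, i.e. uniform messages, unless `init_log_messages` is given.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def loopy_bp(log_unary, log_pair, edges, num_iters=100, damping=0.5,
             init_log_messages=None):
    """Sum-product loopy BP in the log domain for a pairwise model.

    log_unary[i] has shape (K_i,); log_pair[e] has shape (K_i, K_j) for
    edges[e] == (i, j). msgs[(i, j)] is the log-message from i to j.
    Returns a list of normalized node beliefs.
    """
    num_vars = len(log_unary)
    neighbors = {i: [] for i in range(num_vars)}
    pair = {}
    for e, (i, j) in enumerate(edges):
        neighbors[i].append(j)
        neighbors[j].append(i)
        pair[(i, j)] = log_pair[e]      # indexed [x_i, x_j]
        pair[(j, i)] = log_pair[e].T    # indexed [x_j, x_i]
    # Default initialization: all-zero log-messages, i.e. uniform messages.
    msgs = {(i, j): np.zeros(len(log_unary[j])) for (i, j) in pair}
    if init_log_messages is not None:
        msgs.update(init_log_messages)
    for _ in range(num_iters):
        new_msgs = {}
        for (i, j) in msgs:
            # Unary term plus messages into i from every neighbor except j.
            h = log_unary[i] + sum(msgs[(k, i)] for k in neighbors[i] if k != j)
            m = _logsumexp(h[:, None] + pair[(i, j)], axis=0)
            m = m - np.max(m)  # normalize so the largest entry is 0
            # Damping: keep a fraction `damping` of the previous message.
            new_msgs[(i, j)] = damping * msgs[(i, j)] + (1 - damping) * m
        msgs = new_msgs
    beliefs = []
    for i in range(num_vars):
        b = log_unary[i] + sum(msgs[(k, i)] for k in neighbors[i])
        b = np.exp(b - np.max(b))
        beliefs.append(b / b.sum())
    return beliefs
```

On a tree, the fixed point matches the exact marginals regardless of the initial messages; on graphs with cycles, LBP is approximate and the result can depend on initialization and damping.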

See our companion paper for more details.

PGMax is under active development. APIs may change without notice, and expect rough edges!

Installation

Install from PyPI

pip install pgmax

Install latest version from GitHub

pip install git+https://github.com/deepmind/PGMax.git

Developer

While you can install PGMax in your standard Python environment, we strongly recommend using a Python virtual environment to manage your dependencies. This should help avoid version conflicts and generally make installation easier.

git clone https://github.com/deepmind/PGMax.git
cd PGMax
python3 -m venv pgmax_env
source pgmax_env/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
python3 setup.py develop

Install on GPU

By default, the above commands install JAX for CPU. If you have access to a GPU, follow the official JAX installation instructions to install JAX with GPU support.
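After installing, one way to check which backend JAX actually picked up is sketched below. `jax.default_backend()` and `jax.devices()` are standard JAX functions; `describe_jax_backend` is a hypothetical helper that also handles the case where JAX is not installed at all.

```python
import importlib.util


def describe_jax_backend():
    """Return (backend, devices) for the installed JAX, or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    # default_backend() names the platform (e.g. "cpu"); devices() lists the devices.
    return jax.default_backend(), jax.devices()


info = describe_jax_backend()
if info is None:
    print("JAX is not installed")
else:
    backend, devices = info
    print(f"JAX backend: {backend}; devices: {devices}")
```

If this reports a CPU backend on a machine with a GPU, the CPU-only JAX build is most likely still installed.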

Getting Started

Here are a few self-contained Colab notebooks to help you get started on using PGMax:

Citing PGMax

PGMax is part of the DeepMind JAX ecosystem. If you use PGMax in your work, please consider citing our companion paper:

@article{zhou2022pgmax,
  author = {Zhou, Guangyao and Dedieu, Antoine and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
  title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
  journal = {arXiv preprint arXiv:2202.04110},
  year={2022}
}

and using the DeepMind JAX Ecosystem citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}

Note

This is not an officially supported Google product.

pgmax's Issues

Calculating the Log Partition Function in LBP

How can we calculate the log partition function, given that we can obtain node marginals through inferer.get_beliefs? Is it also feasible to obtain factor marginals for this purpose? The typical approach to computing the log partition function is as follows:

$$\log Z \approx \sum_{i\in V} (1-d_i)H_i + \sum_{(i,j)\in E} I_{ij}$$

This formula requires both the entropy $H_i$ of each node and the mutual information $I_{ij}$ of each edge. To compute these, we need the node marginals $b_i(x_i)$ and the pairwise joint marginals $b_{ij}(x_i,x_j)$.
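For reference, for a pairwise model with unary potentials $\phi_i$, pairwise potentials $\psi_{ij}$, and $d_i$ the degree of node $i$, the standard Bethe approximation also contains an average-energy term, and its entropy part can be written with either joint entropies or mutual information:

$$\log Z \approx \sum_{i\in V}\sum_{x_i} b_i(x_i)\log\phi_i(x_i) + \sum_{(i,j)\in E}\sum_{x_i,x_j} b_{ij}(x_i,x_j)\log\psi_{ij}(x_i,x_j) + \sum_{(i,j)\in E} H_{ij} + \sum_{i\in V}(1-d_i)H_i$$

where $H_{ij} = H_i + H_j - I_{ij}$, so the last two sums equal $\sum_{i\in V} H_i - \sum_{(i,j)\in E} I_{ij}$. The numpy sketch below evaluates this from node and edge beliefs; `bethe_log_z` is a hypothetical helper, not a PGMax function, and it does not address how to obtain edge beliefs from PGMax. On a tree, plugging in the exact marginals makes the approximation exact, which gives a simple sanity check.

```python
import numpy as np


def _entropy(p):
    """Shannon entropy (in nats) of a probability array, with 0 log 0 = 0."""
    p = np.asarray(p, dtype=float).ravel()
    p = p[p > 0]
    return -np.sum(p * np.log(p))


def bethe_log_z(log_unary, log_pair, edges, node_beliefs, edge_beliefs):
    """Bethe approximation of log Z for a pairwise model.

    log_unary[i], node_beliefs[i]: shape (K_i,)
    log_pair[e], edge_beliefs[e]: shape (K_i, K_j) for edges[e] == (i, j)
    Assumes finite log potentials (0 * -inf would give nan).
    """
    degree = np.zeros(len(log_unary), dtype=int)
    for i, j in edges:
        degree[i] += 1
        degree[j] += 1
    # Average-energy term: expected log potentials under the beliefs.
    neg_energy = sum(np.sum(b * lu) for b, lu in zip(node_beliefs, log_unary))
    neg_energy += sum(np.sum(b * lp) for b, lp in zip(edge_beliefs, log_pair))
    # Bethe entropy: edge entropies plus (1 - d_i) times node entropies.
    entropy = sum(_entropy(b) for b in edge_beliefs)
    entropy += sum((1 - d) * _entropy(b) for d, b in zip(degree, node_beliefs))
    return neg_energy + entropy
```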

Installation seems to be broken

I just installed PGMax via pip (pip install pgmax). While the installation itself succeeded, I noticed that it did not seem to install any of the dependencies. Indeed, when I opened a terminal and ran from pgmax import fgraph, I got ModuleNotFoundError: No module named 'jax'.

The core issue seems to be that setup.py lists no requirements, and requirements.txt is not linked to setup.py in any way (indeed, this commit seems to have removed the link between setup.py and requirements.txt). If I clone the repo and run pip install -r requirements.txt, everything works, but installation via pip alone still seems broken.
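A minimal sketch of the fix the issue points to, assuming requirements.txt contains only plain requirement specifiers: parse the file in setup.py and pass the result to install_requires. `parse_requirements` and `read_requirements` are hypothetical helper names; this is not PGMax's actual setup.py.

```python
from pathlib import Path


def parse_requirements(text):
    """Return the requirement specifiers in the text of a requirements file.

    Blank lines, comments, and pip options (lines starting with "-", such as
    "-r other.txt" or "--extra-index-url ...") are skipped, because
    install_requires only accepts plain requirement specifiers.
    """
    requirements = []
    for raw_line in text.splitlines():
        # Simplification: treat everything after "#" as a comment.
        line = raw_line.split("#", 1)[0].strip()
        if line and not line.startswith("-"):
            requirements.append(line)
    return requirements


def read_requirements(path="requirements.txt"):
    """Read and parse a requirements file."""
    return parse_requirements(Path(path).read_text())
```

In setup.py this would be used as setuptools.setup(..., install_requires=read_requirements()). If the package is also published as a source distribution, requirements.txt has to be included in it (for example via MANIFEST.in); otherwise setup.py cannot find the file at install time.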

Problem in colab notebooks

Hi.
When running the notebooks in Colab (for example rcn.ipynb), there is a problem in the installation cell. This is because !cd PGMax only changes the directory inside a temporary subshell, so the change does not persist for the rest of the notebook.
Solution: replace this line with %cd PGMax.
P.S.: I can make a PR for this.
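The reason !cd has no lasting effect can be reproduced in plain Python, assuming that !cmd behaves roughly like running cmd in a child shell (which is how IPython executes it): a cd in a child process cannot change the parent's working directory, whereas %cd changes the kernel's own directory, much like os.chdir.

```python
import os
import subprocess
import tempfile

start_dir = os.getcwd()
target_dir = tempfile.mkdtemp()

# Like "!cd ...": the cd runs in a child shell that exits immediately.
subprocess.run(f'cd "{target_dir}"', shell=True, check=True)
after_shell_cd = os.getcwd()  # unchanged: still start_dir

# Like "%cd ...": the current (kernel) process changes its own directory.
os.chdir(target_dir)
after_chdir = os.getcwd()  # now target_dir

os.chdir(start_dir)  # restore the original working directory
os.rmdir(target_dir)  # clean up the empty temporary directory
```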

some (very) small suggestions on the examples

  • In examples/gmrf.ipynb, maybe replace from jax.example_libraries import optimizers with optax. Also the term "GMRF" usually refers to Gaussian MRF, not grid :) Finally the learning code is super slow, even on an A100... (public colab)
  • in examples/rbm.ipynb, you mix jax.vmap with np.random, which seems dangerous...
  • in examples/ising_model.ipynb, the line batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {variables: 0}), out_axes=0)) is very obscure. What is the dict with a key called 'variables'? And in what sense are these evidence / potential 'updates', as opposed to just static things?
  • it would be helpful to add a link to the ising example from https://github.com/deepmind/PGMax/tree/main#getting-started
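Two small JAX sketches for the rbm and ising points; the function names and shapes are made up for illustration and are not taken from the notebooks. First, np.random inside a function passed to jax.vmap runs once, while the function is traced, so every batch element receives the same sample; drawing from jax.random with one key per element avoids this. Second, in_axes mirrors the structure of the arguments: in (None, {variables: 0}) the first argument is shared across the batch and the second is a dict whose entry is batched along axis 0. In the ising notebook the dict key is presumably the variable group itself, matching evidence dicts keyed by variable group such as {root_variables: ...}.

```python
import jax
import jax.numpy as jnp
import numpy as np


def add_numpy_noise(x):
    # np.random is called once at trace time; the result is a constant.
    return x + np.random.normal()


def add_jax_noise(key, x):
    # jax.random is traced, so each batch element uses its own key.
    return x + jax.random.normal(key)


xs = jnp.zeros(4)
same_noise = jax.vmap(add_numpy_noise)(xs)
keys = jax.random.split(jax.random.PRNGKey(0), 4)
distinct_noise = jax.vmap(add_jax_noise)(keys, xs)


def loss(scale, evidence):
    # evidence is a dict of arrays, analogous to an evidence dict.
    return scale * jnp.sum(evidence["obs"])


# None: share `scale` across the batch; {"obs": 0}: batch evidence["obs"] on axis 0.
batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {"obs": 0}), out_axes=0))
losses = batch_loss(2.0, {"obs": jnp.ones((5, 3))})  # one loss per batch element
```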

VarDict issues

I'm trying to use named variables via vgroup.VarDict, but without much success. I cannot find any documentation or example usage of VarDict. Specifically, I'm running into an issue when trying to update evidence.

In pgmax/vgroup/vdict.py, the flatten function expects a mapping from variable names to data (e.g. evidence), but pgmax/infer/bp_state.py passes only the evidence itself: flat_data = name.flatten(data)

An example snippet:

import itertools
import math

import numpy as np
from pgmax import factor, fgraph, infer, vgroup

# ("a",) is a one-element tuple; ("a") would just be the string "a".
root_variables = vgroup.VarDict(num_states=2, variable_names=("a",))
leaf_variables = vgroup.VarDict(num_states=2, variable_names=("b",))
fg = fgraph.FactorGraph(variable_groups=[root_variables, leaf_variables])

# Prior P(a): log-probabilities for a = 0 and a = 1.
r1_factor = factor.EnumFactor(
    variables=[root_variables["a"]],
    factor_configs=np.array([[0], [1]]),
    log_potentials=np.array([math.log(0.9), math.log(0.1)]),
)
fg.add_factors(r1_factor)

# Conditional P(b | a), configs (a, b) enumerated in row-major order.
pairwise_factor = factor.EnumFactor(
    variables=[root_variables["a"], leaf_variables["b"]],
    factor_configs=np.array(list(itertools.product(np.arange(2), repeat=2))),
    log_potentials=np.array([math.log(0.9), math.log(0.1), math.log(0.2), math.log(0.8)]),
)
fg.add_factors(pairwise_factor)

evidence_updates = {root_variables: np.array([math.log(0.9), math.log(0.1)])}
inferer = infer.build_inferer(fg.bp_state, backend="bp")
inferer_arrays = inferer.init(evidence_updates=evidence_updates)  # fails here

The snippet fails at the init step when trying to update the evidence.
I'm not sure whether this is an actual bug or whether I simply built the graph incorrectly.

It would be great to have an example of inference in a directed graphical model, e.g. the standard sprinkler example, which would be really useful for getting a handle on how to use the package.
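The key step for a directed model can be sketched in numpy: each conditional probability table P(child | parents) becomes one EnumFactor over [parents..., child], whose factor_configs enumerate every joint state and whose log_potentials are the log-probabilities in the same order. This follows the same pattern as pairwise_factor in the VarDict snippet. `cpt_to_enum_factor_arrays` is a hypothetical helper, and the probabilities are illustrative numbers for the WetGrass node of a sprinkler network.

```python
import numpy as np


def cpt_to_enum_factor_arrays(cpt):
    """Turn a CPT into (factor_configs, log_potentials) arrays.

    cpt has one axis per variable, parents first and the child last, e.g.
    cpt[s, r, w] = P(W = w | S = s, R = r). Rows of factor_configs list every
    joint state in C (row-major) order, matching np.log(cpt).ravel().
    """
    factor_configs = np.stack(
        np.unravel_index(np.arange(cpt.size), cpt.shape), axis=1)
    with np.errstate(divide="ignore"):  # log(0) = -inf for impossible states
        log_potentials = np.log(cpt).ravel()
    return factor_configs, log_potentials


# Illustrative P(WetGrass | Sprinkler, Rain), axes (S, R, W), 0 = false, 1 = true.
p_wet_given_s_r = np.array([
    [[1.00, 0.00], [0.10, 0.90]],
    [[0.10, 0.90], [0.01, 0.99]],
])
configs, log_pots = cpt_to_enum_factor_arrays(p_wet_given_s_r)
```

The two arrays would then be passed as factor_configs and log_potentials to an EnumFactor whose variables are ordered [S, R, W]. Zero probabilities become -inf log potentials, which may need special care in LBP.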

Thank you for the nice library and some questions

Dear PGMax Team,

First of all, thank you very much for the library! It works super well, and we found it to be orders of magnitude faster than other libraries. Furthermore, the flattened potentials are super neat :)

If you don't mind, I have two questions and would be very grateful if you could provide some pointers:

  • prob == 0: We have unary potentials that can be (numerically) zero. This leads to errors where the prediction seems to break down completely for some variables, even though there are variable-to-state assignments that would avoid prob == 0. Initially, we replaced the -np.inf log potentials with your pgmax.utils.NEG_INF = -1e20. While this fixed some issues, we still got random assignments in some cases, which seem to disappear if we use, say, -20 as the smallest unary log potential. I have a hunch that this happens specifically for temperature=0 when getting map_states, but I couldn't investigate. What is the guidance on this? Could you point us to a good paper with some experiments on this? Is there an intuitive explanation for why LBP fails in this case?

  • Initialization: Currently, we use bp_arrays = bp.run_bp(bp.init(), num_iters=num_iters, damping=damping), as given in most tutorials. This seems to initialize the log_potentials to 0, right? Is there some way we could use prior knowledge here, e.g. initializing with the most likely states given the unary potentials?

  • If you have any other papers that give good hints on how to squeeze the most out of LBP, I would be happy to read them.

Best and Thanks
Lukas
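One possible mechanism behind the prob == 0 symptoms, offered as a hedged illustration rather than a diagnosis of PGMax's internals: in log-domain message passing, a message whose entries are all -inf normalizes to nan, because -inf - (-inf) is nan. A huge finite floor such as NEG_INF = -1e20 avoids the nan, but in float32 (JAX's default precision) it absorbs ordinary log potentials entirely, so states that should differ become exact ties, and with temperature=0 the decoded state among them is then decided by tie-breaking rather than by the potentials. A moderate floor such as -20 keeps those differences representable. `normalize_log_message` is a hypothetical helper.

```python
import numpy as np


def normalize_log_message(log_msg):
    """Shift a log-domain message so that its largest entry is 0."""
    return log_msg - np.max(log_msg)


# 1) All entries -inf (every state ruled out): the max is -inf and
#    -inf - (-inf) is nan, which then propagates through later updates.
with np.errstate(invalid="ignore"):
    nan_msg = normalize_log_message(np.array([-np.inf, -np.inf], dtype=np.float32))

# 2) A huge finite floor such as -1e20 swamps ordinary log potentials in
#    float32, so two states that should differ by 2 become an exact tie.
huge_floor = np.float32(-1e20)
tied_a = huge_floor + np.float32(5.0)
tied_b = huge_floor + np.float32(3.0)

# 3) A moderate floor such as -20 keeps the difference representable.
moderate_floor = np.float32(-20.0)
kept_a = moderate_floor + np.float32(5.0)
kept_b = moderate_floor + np.float32(3.0)
```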
