A library for probabilistic models using Haiku and JAX
Ramsey
Probabilistic modelling using Haiku and JAX
About
Ramsey is a library for probabilistic models using Haiku and JAX. It builds upon the same module system that Haiku uses and is hence fully compatible with Haiku's and NumPyro's APIs. Ramsey implements (or rather, intends to implement) neural and Gaussian process models as well as normalizing flows.
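To illustrate what a Gaussian process model computes, here is a minimal plain-JAX sketch of exact GP regression with a squared-exponential kernel. It is not Ramsey's API. The function names, the fixed kernel hyperparameters (lengthscale and variance of 1.0) and the noise level are assumptions chosen for the example.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel between two sets of 1-D inputs."""
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def gp_posterior(x_train, y_train, x_test, noise=1e-2):
    """Exact GP posterior mean and marginal variance at x_test.

    The hyperparameters are fixed assumptions here, not fitted.
    """
    k_xx = rbf_kernel(x_train, x_train) + noise * jnp.eye(x_train.shape[0])
    k_xs = rbf_kernel(x_train, x_test)
    k_ss = rbf_kernel(x_test, x_test)
    # Cholesky factor of the noisy training covariance (lower triangular)
    chol = jnp.linalg.cholesky(k_xx)
    # alpha = K^{-1} y, solved with the Cholesky factor
    alpha = cho_solve((chol, True), y_train[:, None])
    mean = (k_xs.T @ alpha)[:, 0]
    # predictive variance: k(x*, x*) - k_*^T K^{-1} k_*
    v = solve_triangular(chol, k_xs, lower=True)
    var = jnp.diag(k_ss) - jnp.sum(v**2, axis=0)
    return mean, var


# usage: condition on noiseless samples of sin(x)
x = jnp.linspace(-3.0, 3.0, 20)
y = jnp.sin(x)
mean, var = gp_posterior(x, y, x)
```

Far from the training inputs the posterior mean falls back to the prior mean of zero and the variance returns to the prior variance.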
Installation
To install the latest GitHub release, just call the following on the command line:
pip install git+https://github.com/dirmeier/ramsey@v0.0.1
See also the installation instructions for Haiku and JAX.
Example usage
Ramsey uses Haiku's module system to construct probabilistic models and define parameters. For instance, a simple neural process can be constructed like this:
import haiku as hk
import jax.random as random

from ramsey.data import sample_from_sinus_function
from ramsey.models import NP


def neural_process(**kwargs):
    dim = 128
    np = NP(
        # the decoder's last layer has two outputs per target point
        decoder=hk.nets.MLP([dim] * 3 + [2]),
        # a pair of MLPs for the latent path
        latent_encoder=(
            hk.nets.MLP([dim] * 3), hk.nets.MLP([dim, dim * 2])
        )
    )
    return np(**kwargs)


key = random.PRNGKey(23)
(x, y), _ = sample_from_sinus_function(key)

# hk.transform turns the function into a pair of pure init/apply functions
neural_process = hk.transform(neural_process)
# init creates the model parameters from example inputs
params = neural_process.init(key, x_context=x, y_context=y, x_target=x)
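The neural process above is built from a decoder and a pair of latent-encoder MLPs. The following plain-JAX sketch shows what a latent neural process of this shape typically computes: encode each context pair, mean-pool the encodings, map the pooled vector to the mean and scale of a latent variable, and decode the target inputs together with a latent sample. It is not Ramsey's implementation. The helpers init_mlp, mlp, latent_distribution and decode, the reading of the dim * 2 encoder output as a mean and a scale, the small layer width, and the positivity transform for the scales are all assumptions.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Random weights for a plain MLP with layer widths `sizes` (hypothetical helper)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp(params, h):
    """ReLU MLP whose last layer is linear."""
    for w, b in params[:-1]:
        h = jax.nn.relu(h @ w + b)
    w, b = params[-1]
    return h @ w + b


def positive(raw):
    """Assumed positivity transform for scale parameters."""
    return 0.1 + 0.9 * jax.nn.softplus(raw)


def latent_distribution(enc1, enc2, x_context, y_context):
    """Map a context set to the mean and scale of the latent variable z."""
    pairs = jnp.concatenate([x_context, y_context], axis=-1)  # (n, dx + dy)
    # mean-pool over context points, so the result ignores their order
    r = mlp(enc1, pairs).mean(axis=0)
    mu, raw_scale = jnp.split(mlp(enc2, r), 2, axis=-1)
    return mu, positive(raw_scale)


def decode(dec, x_target, z):
    """Predict the mean and scale of y at each target input given one z."""
    z_rep = jnp.broadcast_to(z, (x_target.shape[0], z.shape[-1]))
    out = mlp(dec, jnp.concatenate([x_target, z_rep], axis=-1))
    mean, raw_scale = jnp.split(out, 2, axis=-1)
    return mean, positive(raw_scale)


# usage: small widths instead of 128 to keep the sketch cheap
dim = 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(23), 4)
x_context = jnp.linspace(-3.0, 3.0, 10)[:, None]
y_context = jnp.sin(x_context)
enc1 = init_mlp(k1, [2, dim, dim])
enc2 = init_mlp(k2, [dim, dim, 2 * dim])
dec = init_mlp(k3, [1 + dim, dim, 2])
mu, sigma = latent_distribution(enc1, enc2, x_context, y_context)
z = mu + sigma * jax.random.normal(k4, mu.shape)  # reparameterised sample
y_mean, y_scale = decode(dec, x_context, z)
```

Mean pooling is what lets such a model condition on context sets of any size, in any order.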
Why Ramsey
Just as other probabilistic languages are named after researchers in the field (e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, Frank Ramsey.
Author
Simon Dirmeier simon.dirmeier @ protonmail com