Probability distributions for JAX
PJAX
Do you just want probability distributions for JAX without all the added baggage of tensorflow-probability or numpyro? Do you have some weird distribution that is not available in those libraries or in scipy.stats?
Then PJAX is for you: lightweight probability distributions with a JAX backend. That's it.
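To make "lightweight" concrete, here is a minimal sketch of a distribution you might want to add yourself, a Kumaraswamy distribution on (0, 1), written in plain JAX. The class name and the `log_pdf` / `sample(key, shape=...)` methods are hypothetical and only mirror the usage shown in this README; this is not PJAX's source code or base-class API.

```python
import jax
import jax.numpy as jnp


class Kumaraswamy:
    """Hypothetical Kumaraswamy(a, b) distribution on (0, 1).

    Not part of PJAX; written only to illustrate how small a
    JAX-backed distribution can be.
    """

    def __init__(self, a, b):
        self.a = jnp.asarray(a)
        self.b = jnp.asarray(b)

    def log_pdf(self, x):
        # f(x) = a * b * x^(a - 1) * (1 - x^a)^(b - 1), evaluated in log space
        a, b = self.a, self.b
        return (jnp.log(a) + jnp.log(b)
                + (a - 1) * jnp.log(x)
                + (b - 1) * jnp.log1p(-x ** a))

    def sample(self, key, shape=()):
        # Inverse-CDF sampling: F(x) = 1 - (1 - x^a)^b, so
        # x = (1 - (1 - u)^(1/b))^(1/a) with u ~ Uniform[0, 1)
        batch_shape = jnp.broadcast_shapes(self.a.shape, self.b.shape)
        u = jax.random.uniform(key, shape + batch_shape)
        return (1 - (1 - u) ** (1 / self.b)) ** (1 / self.a)
```

For example, `Kumaraswamy(2.0, 3.0).sample(jax.random.PRNGKey(0), shape=(1000,))` returns 1000 draws in the unit interval, and `log_pdf` is just a few elementwise `jax.numpy` operations, so it works under `jax.jit` and `jax.grad` without any extra machinery.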
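import jax
import jax.numpy as jnp
from pjax import distributions

# Two Gamma distributions, one per entry of the parameter arrays
a = jnp.asarray([4.3, 0.8])
b = jnp.asarray([1.2, 7.3])
dist = distributions.Gamma(a, b)

# Evaluation points as a column of shape (4, 1), so that (assuming log_pdf
# broadcasts NumPy-style) they pair with the (2,)-shaped parameters;
# a flat (4,) array does not broadcast against shape (2,)
x = jnp.asarray([0.4, 0.5, 0.6, 0.7])[:, None]
dist.log_pdf(x)

# Draw 100 samples using an explicit JAX PRNG key
key = jax.random.PRNGKey(42)
dist.sample(key, shape=(100,))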
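The README does not say whether the second Gamma parameter is a rate or a scale. Assuming it is a rate (so the mean is a / b), the same densities and draws can be cross-checked with JAX's built-in `jax.scipy.stats.gamma.logpdf` and `jax.random.gamma`. This is a sanity-check sketch, not how PJAX is implemented; if PJAX treats `b` as a scale, pass `scale=b` and multiply the raw draws by `b` instead.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma

# Same parameters as the PJAX example
a = jnp.asarray([4.3, 0.8])  # concentration (shape)
b = jnp.asarray([1.2, 7.3])  # ASSUMED to be a rate parameter

# A (4, 1) column broadcasts with the (2,) parameters to a (4, 2) result
x = jnp.asarray([0.4, 0.5, 0.6, 0.7])[:, None]
log_density = gamma.logpdf(x, a, scale=1.0 / b)

# jax.random.gamma draws from Gamma(a, 1); dividing by the rate rescales it
key = jax.random.PRNGKey(42)
samples = jax.random.gamma(key, a, shape=(100, 2)) / b
```

Under the rate assumption, `log_density` equals a·log(b) − log Γ(a) + (a − 1)·log(x) − b·x elementwise, which is a quick way to check which parameterization a given library uses.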
Download files
Source distribution: pjax-0.0.2.tar.gz (6.8 kB)
Built distribution: pjax-0.0.2-py3-none-any.whl (9.8 kB)