
Nested Sampling in JAX


JAXNS

Mission: To make nested sampling faster, easier, and more powerful

What is it?

JAXNS is:

  1. a probabilistic programming framework using nested sampling as the engine;
  2. coded in JAX in a manner that allows lowering the entire inference algorithm to XLA primitives, which are JIT-compiled for high performance;
  3. continuously improving on its mission of making nested sampling faster, easier, and more powerful; and
  4. citable; you can read an (old) pre-print at https://arxiv.org/abs/2012.15286.
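To make "lowering the entire inference algorithm to XLA" concrete, here is a minimal sketch of a toy nested sampler written entirely with JAX control flow (`jax.lax.fori_loop`, `jax.lax.while_loop`) and compiled end to end with `jax.jit`. It is not JAXNS's algorithm or API: it assumes a uniform prior on the unit square, replaces dead points by naive rejection sampling from the prior (only viable in very low dimensions), and uses the deterministic shrinkage estimate X_i ≈ exp(-i/n). All names (`toy_nested_sampling`, `sample_above`, `SIGMA`) are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

SIGMA = 0.1  # hypothetical width of the toy likelihood


def log_likelihood(x):
    # Normalised isotropic Gaussian centred at 0.5 in every dimension. With a
    # uniform prior on the unit hypercube the evidence is ~1, so log Z ~ 0.
    d = x.shape[-1]
    r2 = jnp.sum((x - 0.5) ** 2, axis=-1)
    return -0.5 * r2 / SIGMA**2 - d * jnp.log(SIGMA) - 0.5 * d * jnp.log(2 * jnp.pi)


def sample_above(key, log_l_min, dim):
    """Naive rejection sampling: draw from the prior until log L > log_l_min."""
    def cond(state):
        return state[2] <= log_l_min

    def body(state):
        key, _, _ = state
        key, sub = jax.random.split(key)
        x = jax.random.uniform(sub, (dim,))
        return key, x, log_likelihood(x)

    init = (key, jnp.zeros(dim), jnp.array(-jnp.inf, dtype=jnp.float32))
    _, x, log_l = jax.lax.while_loop(cond, body, init)
    return x, log_l


@partial(jax.jit, static_argnames=("n_live", "n_iter", "dim"))
def toy_nested_sampling(key, n_live=100, n_iter=600, dim=2):
    key, sub = jax.random.split(key)
    live_x = jax.random.uniform(sub, (n_live, dim))
    live_log_l = log_likelihood(live_x)
    # log(X_i - X_{i+1}) = -i / n_live + log(1 - exp(-1 / n_live)).
    log_shrink = jnp.log1p(-jnp.exp(-1.0 / n_live))

    def step(i, state):
        key, live_x, live_log_l, log_z = state
        worst = jnp.argmin(live_log_l)
        log_l_worst = live_log_l[worst]
        # Accumulate the dead point's contribution L_i * (X_i - X_{i+1}).
        log_z = jnp.logaddexp(log_z, log_l_worst - i / n_live + log_shrink)
        # Replace it with a prior draw constrained to L > L_worst.
        key, sub = jax.random.split(key)
        new_x, new_log_l = sample_above(sub, log_l_worst, dim)
        live_x = live_x.at[worst].set(new_x)
        live_log_l = live_log_l.at[worst].set(new_log_l)
        return key, live_x, live_log_l, log_z

    init = (key, live_x, live_log_l, jnp.array(-jnp.inf, dtype=jnp.float32))
    _, _, live_log_l, log_z = jax.lax.fori_loop(0, n_iter, step, init)
    # Add the remaining live points: X_final * mean(L_live).
    log_remaining = logsumexp(live_log_l) - jnp.log(n_live) - n_iter / n_live
    return jnp.logaddexp(log_z, log_remaining)


log_z = toy_nested_sampling(jax.random.PRNGKey(0))
print(f"toy log Z = {float(log_z):.3f} (analytic value is ~0)")
```

Because every loop is an XLA control-flow primitive, the whole run compiles to a single program; a production sampler such as JAXNS replaces the rejection step with far more efficient constrained samplers.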

Documentation

You can read the documentation here.

Install

Notes:

  1. JAXNS requires Python >= 3.8.
  2. It is highly recommended to use a separate virtual environment for each project. With miniconda installed, run:
# To create a new env, if necessary
conda create -n jaxns_py python=3.11
conda activate jaxns_py

For end users

Install directly from PyPI:

pip install jaxns

For development

Clone the repository and install:

git clone https://www.github.com/JoshuaAlbert/jaxns.git
cd jaxns
pip install -r requirements.txt
pip install -r requirements-tests.txt
pip install -r requirements-examples.txt
pip install .

Getting help and contributing examples

Do you have a neat Bayesian problem and want to solve it with JAXNS? I really encourage anyone in the scientific community or industry to get involved and join the discussion forum. Please use the GitHub discussion forum to get help or to contribute examples and neat use cases.

Quick start

Check out the examples here.

Caveats

The main caveat is that you need to be able to define your likelihood function with JAX. This is usually no big deal, because jax.numpy is largely a drop-in replacement for NumPy and many likelihoods can be expressed with it. If you're unfamiliar with JAX, take a quick tour (https://jax.readthedocs.io/en/latest/notebooks/quickstart.html).
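For instance, a Gaussian likelihood written with `jax.numpy` reads almost exactly like its NumPy counterpart, and the same function can then be JIT-compiled, vectorised with `jax.vmap`, and differentiated with `jax.grad`. This is a sketch in plain JAX, not JAXNS's model-building API; the function name `log_likelihood` and the data values are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_likelihood(theta, data, sigma):
    """Gaussian log-likelihood of `data` with mean `theta` and known noise `sigma`."""
    n = data.shape[0]
    resid = (data - theta) / sigma
    return (-0.5 * jnp.sum(resid ** 2)
            - n * jnp.log(sigma)
            - 0.5 * n * jnp.log(2.0 * jnp.pi))


data = jnp.array([0.9, 1.1, 1.0, 1.2])
sigma = 0.5

# The same NumPy-style function can be compiled, batched and differentiated.
ll_jit = jax.jit(log_likelihood)
ll_batched = jax.jit(jax.vmap(log_likelihood, in_axes=(0, None, None)))
dll_dtheta = jax.grad(log_likelihood)

print(ll_jit(1.0, data, sigma))
print(ll_batched(jnp.linspace(0.5, 1.5, 5), data, sigma))
print(dll_dtheta(1.0, data, sigma))
```

The main things to avoid are Python-side control flow on traced values and in-place array mutation; `jnp.where`, `jax.lax` control flow and `.at[...].set(...)` cover most such cases.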

Speed test comparison with other nested sampling packages

JAXNS is fast because it is built on JAX. On cheap likelihood evaluations it is typically two to three orders of magnitude faster than PolyChord, MultiNest, and dynesty, as shown in the pre-print (https://arxiv.org/abs/2012.15286). Regarding how efficiently JAXNS uses likelihood evaluations, it prizes exactness over efficiency; however, because it employs an adaptive strategy, users can control efficiency by adjusting some precision parameters.
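The precision/efficiency trade-off can be made concrete with the standard rules of thumb for textbook nested sampling: with n live points and posterior information H (in nats), reaching the bulk of the posterior takes roughly n·H iterations, and the statistical error on log Z is roughly sqrt(H/n). The helper below only evaluates these two formulas; it is a hedged illustration of the general trade-off, not a description of JAXNS's own precision parameters, and `nested_sampling_budget` is a hypothetical name.

```python
import math
from typing import Tuple


def nested_sampling_budget(information_nats: float, n_live: int) -> Tuple[float, float]:
    """Rule-of-thumb cost and precision of textbook nested sampling.

    Returns (approximate iterations to reach the posterior bulk,
             approximate standard deviation of log Z).
    """
    iterations = n_live * information_nats            # ~ n * H
    log_z_std = math.sqrt(information_nats / n_live)  # ~ sqrt(H / n)
    return iterations, log_z_std


# Doubling the number of live points roughly doubles the cost but only
# shrinks the log-evidence error by a factor of sqrt(2).
for n_live in (50, 100, 200, 400):
    iters, err = nested_sampling_budget(information_nats=5.0, n_live=n_live)
    print(f"n_live={n_live:4d}  iterations~{iters:7.0f}  std(log Z)~{err:.3f}")
```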

Change Log

5 Oct, 2023 -- JAXNS 2.2.6 released. Minor update to evidence maximisation.

3 Oct, 2023 -- JAXNS 2.2.5 released. Parametrised priors, and evidence maximisation added.

24 Sept, 2023 -- JAXNS 2.2.4 released. Added marginalising from saved U samples.

28 July, 2023 -- JAXNS 2.2.3 released. Bug fix for singular priors.

26 June, 2023 -- JAXNS 2.2.1 released. Multi-ellipsoidal sampler added back in. Adaptive refinement disabled, as a bias has been detected in it.

15 June, 2023 -- JAXNS 2.2.0 released. Added support for using TFP bijectors to define transformed distributions. Other minor improvements.

15 April, 2023 -- JAXNS 2.1.0 released. pmap is used on the outermost loops, allowing efficient device-to-device communication during parallel runs.

8 March, 2023 -- JAXNS 2.0.1 released. Changed how we do annotations to support Python 3.8 again.

3 January, 2023 -- JAXNS 2.0 released. Complete overhaul of components. New way to build models.

5 August, 2022 -- JAXNS 1.1.1 released. Pytree shaped priors.

2 June, 2022 -- JAXNS 1.1.0 released. Dynamic sampling takes advantage of adaptive refinement. Parallelisation. Bayesian opt and global opt modules.

30 May, 2022 -- JAXNS 1.0.1 released. Improvements to speed, parallelisation, and structure of code.

9 April, 2022 -- JAXNS 1.0.0 released. Parallel sampling, dynamic search, and adaptive refinement. Global optimiser released.

2 Jun, 2021 -- JAXNS 0.0.7 released.

13 May, 2021 -- JAXNS 0.0.6 released.

8 Mar, 2021 -- JAXNS 0.0.5 released.

8 Mar, 2021 -- JAXNS 0.0.4 released.

7 Mar, 2021 -- JAXNS 0.0.3 released.

28 Feb, 2021 -- JAXNS 0.0.2 released.

28 Feb, 2021 -- JAXNS 0.0.1 released.

1 January, 2021 -- Paper submitted


Download files


Source Distribution

jaxns-2.2.6.tar.gz (79.5 kB)

Uploaded Source

Built Distribution


jaxns-2.2.6-py3-none-any.whl (94.5 kB)

Uploaded Python 3

File details

Details for the file jaxns-2.2.6.tar.gz.

File metadata

  • Download URL: jaxns-2.2.6.tar.gz
  • Upload date:
  • Size: 79.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for jaxns-2.2.6.tar.gz
Algorithm Hash digest
SHA256 f64d44a83de9feeb7be2eeb5185d1fd084981f8ce4885d145de77a85ed9b3518
MD5 133e5f90165f0162c33c0724f18e9ba8
BLAKE2b-256 73b86c3dde31c86ca5829edd740beaa50351f31b8485b78a6d17286c7a602a0b
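A downloaded file can be checked against the published SHA256 digest with nothing but the standard library. This is a minimal sketch; the helper names `sha256_of_file` and `matches_expected` are hypothetical, and the expected digest is the SHA256 value listed above for the source distribution.

```python
import hashlib
import os

# SHA256 digest listed above for jaxns-2.2.6.tar.gz.
EXPECTED_SDIST_SHA256 = "f64d44a83de9feeb7be2eeb5185d1fd084981f8ce4885d145de77a85ed9b3518"


def sha256_of_file(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected=EXPECTED_SDIST_SHA256):
    """Return True if the file's SHA-256 equals the published digest."""
    return sha256_of_file(path) == expected.strip().lower()


if os.path.exists("jaxns-2.2.6.tar.gz"):
    print("hash OK:", matches_expected("jaxns-2.2.6.tar.gz"))
```

pip can enforce the same check in hash-checking mode, e.g. a requirements line `jaxns==2.2.6 --hash=sha256:<digest>` installed with `pip install --require-hashes -r requirements.txt`; in that mode every dependency must also be pinned and hashed.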


File details

Details for the file jaxns-2.2.6-py3-none-any.whl.

File metadata

  • Download URL: jaxns-2.2.6-py3-none-any.whl
  • Upload date:
  • Size: 94.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.5

File hashes

Hashes for jaxns-2.2.6-py3-none-any.whl
Algorithm Hash digest
SHA256 6f7e1446736025cf55ead7867338f66d4392b54426881e26d2bbf49594fbf41b
MD5 818c714875575a8c56e360de04ad214e
BLAKE2b-256 e0a9dfa1ea8289755dccdccb4cb315eefbb86b970caffba425c471b178b644ac

