Common backend for JAX or NumPy.

Project description

Jumpy is a common backend for JAX or NumPy:

  • A Jumpy function returns JAX outputs if given JAX inputs
  • A Jumpy function returns JAX outputs if jitted
  • Otherwise, a Jumpy function returns NumPy outputs
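These three rules come down to choosing the array module from the inputs. Here is a minimal sketch of that idea. The helper `which_np` and the function `scale_and_shift` are hypothetical, written only for illustration, and are not necessarily how Jumpy implements its dispatch. The sketch relies on the fact that under `jax.jit` the arguments are tracers, which pass an `isinstance(x, jax.Array)` check:

```python
import jax
import jax.numpy as jnp
import numpy as np


def which_np(*args):
    """Hypothetical helper: return jax.numpy if any argument is a JAX array,
    otherwise plain NumPy.

    Under jax.jit the arguments are tracers, and tracers count as jax.Array
    instances, so jitted calls also select jax.numpy.
    """
    if any(isinstance(a, jax.Array) for a in args):
        return jnp
    return np


def scale_and_shift(x, scale, shift):
    # Pick the backend from the inputs, then use the shared NumPy-style API.
    xp = which_np(x, scale, shift)
    return xp.add(xp.multiply(x, scale), shift)


x = np.array([1.0, 2.0])
out_np = scale_and_shift(x, 2.0, 1.0)                   # NumPy in, NumPy out
out_mixed = scale_and_shift(x, jnp.asarray(2.0), 1.0)   # any JAX input gives JAX out
out_jit = jax.jit(scale_and_shift)(x, 2.0, 1.0)         # jitted gives JAX out
```

This sketch only inspects its arguments. A jitted call that receives no array arguments at all would still fall through to NumPy here, so a complete implementation would also need to detect that it is running under a JAX trace.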

Jumpy lets you write framework-agnostic code that is easy to debug by running it as raw NumPy, yet is just as performant as JAX when jitted.

We maintain this repository primarily to enable writing Gymnasium and PettingZoo wrappers that can be applied to both standard NumPy-based and hardware-accelerated JAX-based environments. However, the package can be used for much more.
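As an illustration of the wrapper use case, here is a sketch of a single observation transform written once and applied to both a NumPy observation and a jitted JAX observation. The function `clip_obs` is hypothetical. It is not part of Jumpy, Gymnasium, or PettingZoo, and its backend check is a simple type test standing in for Jumpy's own dispatch:

```python
import jax
import jax.numpy as jnp
import numpy as np


def clip_obs(obs, low, high):
    """Hypothetical wrapper step: clip an observation into [low, high].

    Returns a NumPy array for NumPy observations and a JAX array for JAX
    observations (including traced ones under jax.jit).
    """
    xp = jnp if isinstance(obs, jax.Array) else np
    return xp.clip(obs, low, high)


# Plain NumPy environment: runs eagerly, so values are easy to print and debug.
np_obs = clip_obs(np.array([-3.0, 0.5, 4.0]), -1.0, 1.0)

# JAX-based environment: the same function, compiled with jax.jit.
jit_clip = jax.jit(clip_obs)
jax_obs = jit_clip(jnp.array([-3.0, 0.5, 4.0]), -1.0, 1.0)
```

Both calls produce `[-1.0, 0.5, 1.0]`. The NumPy path stays in NumPy, and the jitted path stays in JAX, so the wrapper code does not need separate versions for each kind of environment.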

Installing Jumpy

To install Jumpy from PyPI:

python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install jax-jumpy

Alternatively, to install Jumpy from source, clone this repository, change into its directory, and then run:

python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install -e .

Download files

Source Distribution

jax_jumpy-0.2.0.tar.gz (94.1 kB)

Built Distribution

jax_jumpy-0.2.0-py3-none-any.whl (11.1 kB)
