
PyTorch-like neural networks in JAX

Project description

Equinox

Equinox is a JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees.

In doing so:

  • We get a PyTorch-like API...
  • ...that's fully compatible with native JAX transformations...
  • ...with no new concepts you have to learn. (It's all just PyTrees.)
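The underlying mechanism is plain JAX. Transformations such as `jax.grad` already accept parameters stored in any PyTree, for example a dict, and return results with the same structure. Equinox makes the model object itself the PyTree. A minimal sketch using only JAX (the names `params`, `apply` and `loss` are illustrative, not part of Equinox):

```python
import jax
import jax.numpy as jnp

# A "model" here is just a PyTree: a dict holding two arrays.
params = {"weight": jnp.ones((3, 2)), "bias": jnp.zeros((3,))}

def apply(params, x):
    # An affine map parameterised by the PyTree `params`.
    return params["weight"] @ x + params["bias"]

def loss(params, x, y):
    return jnp.mean((apply(params, x) - y) ** 2)

x = jnp.array([1.0, 2.0])
y = jnp.zeros((3,))

# jax.grad differentiates with respect to every leaf of `params`
# and returns gradients with the same PyTree structure.
grads = jax.grad(loss)(params, x, y)
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True
```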

The elegance of Equinox is its selling point in a world that already has Haiku, Flax and so on.

(In other words, why should you care? Because Equinox is really simple to learn, and really simple to use.)

Installation

pip install equinox

Requires Python 3.7+ and JAX 0.2.18+.

Documentation

Available at https://docs.kidger.site/equinox.

Quick example

Models are defined using PyTorch-like syntax:

import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

and fully compatible with normal JAX operations:

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)
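Because `grads` has exactly the same PyTree structure as `model`, a plain gradient-descent update can be written with `jax.tree_util.tree_map`. This is an illustrative sketch: `sgd_step` and `learning_rate` are hypothetical names, not Equinox API, and plain dicts stand in for the `Linear` model so the snippet runs without Equinox installed. It assumes every leaf is an array, which holds for the `Linear` above.

```python
import jax
import jax.numpy as jnp

def sgd_step(model, grads, learning_rate=0.1):
    # `grads` mirrors the structure of `model`, so the two can be
    # combined leaf-by-leaf with tree_map.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, model, grads)

# Stand-in PyTrees (hypothetical values) so the snippet is self-contained.
model = {"weight": jnp.ones((3, 2)), "bias": jnp.zeros((3,))}
grads = {"weight": jnp.full((3, 2), 2.0), "bias": jnp.full((3,), 2.0)}
new_model = sgd_step(model, grads)
print(new_model["bias"])
```

Applied to the Equinox example above, the same `sgd_step(model, grads)` should return a new `Linear` with updated `weight` and `bias`, since `tree_map` rebuilds the result with the original PyTree structure.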

Finally, there's no magic behind the scenes. All eqx.Module does is register your class as a PyTree. From that point onwards, JAX already knows how to work with PyTrees.
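For intuition, here is roughly what "registering a class as a PyTree" involves, written by hand against plain JAX's `jax.tree_util.register_pytree_node_class` instead of Equinox. This is a sketch for illustration only; `eqx.Module` does the equivalent bookkeeping for you, and its actual internals may differ. The input `x` is all ones here (rather than zeros) so that the gradients are nonzero.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Linear:
    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from its leaves (JAX may pass tracers here).
        return cls(*children)

model = Linear(jnp.ones((3, 2)), jnp.zeros((3,)))

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

x = jnp.ones((32, 2))
y = jnp.zeros((32, 3))
grads = loss_fn(model, x, y)  # another Linear, holding the gradients
```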

Citation

If you found this library to be useful in academic work, then please cite: (arXiv link)

@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

(Also consider starring the project on GitHub.)

Project details



This version

0.1.6

Download files


Source Distribution

equinox-0.1.6.tar.gz (23.6 kB)

Uploaded Source

Built Distribution


equinox-0.1.6-py3-none-any.whl (28.4 kB)

Uploaded Python 3

File details

Details for the file equinox-0.1.6.tar.gz.

File metadata

  • Download URL: equinox-0.1.6.tar.gz
  • Size: 23.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for equinox-0.1.6.tar.gz
Algorithm Hash digest
SHA256 776a7f2909440cf7c7e0048a4fd14b3a247145129869ceabffc24e3997c27f1b
MD5 894101924727660c86c3cefe8ef3fee8
BLAKE2b-256 17a8b4bbfc9e74878b0ba401cd382fb8f986265811a355037456ac75241613fb


File details

Details for the file equinox-0.1.6-py3-none-any.whl.

File metadata

  • Download URL: equinox-0.1.6-py3-none-any.whl
  • Size: 28.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.8.12

File hashes

Hashes for equinox-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 1292c5a95f89d60c828b2d97a4553a127ce017258d62c89f9b792363d4de7511
MD5 4a1daed7fb17d9a114100850c49014f4
BLAKE2b-256 ca68ad396c25c0a181d6650186927f7bb2c9dba6ad5ba4b532b32a7b5e65dbc3
