Tools for JAX.

This repository implements a variety of tools for the differentiable programming library JAX.

Major components

Tjax’s major components are:

  • A dataclass decorator, dataclass, that facilitates defining structured JAX objects (so-called “pytrees”). It offers:

    • the ability to mark fields as static (not available in chex.dataclass), and

    • a display method that produces formatted text according to the tree structure.

  • A shim for the gradient transformation library optax that supports:

    • easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,

    • gradient transformation objects that can be passed dynamically to jitted functions, and

    • generic type annotations.

  • A pretty printer print_generic for aggregate and vector types, including dataclasses. (See display.) It features:

    • support for traced values,

    • colorized tree output for aggregate structures, and

    • formatted tabular output for arrays (or statistics when there’s no room for tabular output).
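The static-field idea from the dataclass bullet above can be illustrated with plain JAX. The following is a minimal sketch that uses only `jax.tree_util.register_pytree_node_class`, not Tjax's own decorator; the class `Scaled`, its fields, and the function `apply` are hypothetical names chosen for illustration. A static field is stored in the tree structure rather than as a leaf, so it is not traced and must be hashable.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclass(frozen=True)
class Scaled:  # hypothetical example class, not part of Tjax
    weight: jax.Array  # dynamic field: a pytree leaf, traced under jit
    power: int         # static field: stored in the tree structure

    def tree_flatten(self):
        # Returns (dynamic children, static auxiliary data).
        return (self.weight,), self.power

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (weight,) = children
        return cls(weight, aux_data)


@jax.jit
def apply(model: Scaled, x: jax.Array) -> jax.Array:
    # model.power is a plain Python int here, so it may drive Python logic.
    return model.weight * x ** model.power


model = Scaled(jnp.array(2.0), 3)
leaves = jax.tree_util.tree_leaves(model)  # only the dynamic field appears
result = apply(model, jnp.array(2.0))
```

Because `power` lives in the auxiliary data, calling `apply` with a different `power` triggers a fresh compilation, while changing `weight` does not.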

Minor components

Tjax also includes:

  • Versions of custom_vjp and custom_jvp that can be used on methods: custom_vjp_method and custom_jvp_method. (See shims.)

  • Tools for working with cotangents. (See cotangent_tools.)

  • JAX tree registration for NetworkX graph types. (See graph.)

  • Leaky integration leaky_integrate and Ornstein-Uhlenbeck process iteration diffused_leaky_integrate. (See leaky_integral.)

  • An improved version of jax.tree_util.Partial. (See partial.)

  • A testing function assert_tree_allclose that automatically produces testing code, and a related function tree_allclose. (See testing.)

  • Basic tools like divide_where. (See tools.)
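To make the leaky integration bullet above concrete, here is a minimal sketch of the underlying idea in plain Python. It assumes the leaky integrator follows dx/dt = drift - decay * x and advances it exactly over one time step. The function name `leaky_integrate_sketch` and its parameters are hypothetical, and this is not Tjax's implementation or signature.

```python
import math


def leaky_integrate_sketch(value: float, time_step: float,
                           drift: float, decay: float) -> float:
    """Advance dx/dt = drift - decay * x exactly by time_step (decay > 0)."""
    # Fraction of the old value that survives the step.
    kept = math.exp(-decay * time_step)
    # The remainder relaxes toward the fixed point drift / decay.
    return value * kept + (drift / decay) * (1.0 - kept)


# Starting from zero with constant drift, the value relaxes to drift / decay.
x = 0.0
for _ in range(1000):
    x = leaky_integrate_sketch(x, time_step=0.1, drift=2.0, decay=0.5)
```

Because the update is the exact solution rather than an Euler step, two steps of length 0.1 give the same result as one step of length 0.2, up to floating-point error.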

Contribution guidelines

Contributions should be consistent with the surrounding style, be type annotated, and pass the linters below.

To run tests: pytest

There are a few tools to clean and check the source:

  • ruff check

  • pyright

  • mypy

  • isort .

  • pylint tjax tests
