Tools for JAX.
This repository implements a variety of tools for the differentiable programming library JAX.
Major components
Tjax’s major components are:
A dataclass decorator, dataclass, that facilitates defining structured JAX objects (so-called “pytrees”), which benefits from:
the ability to mark fields as static (not available in chex.dataclass), and
a display method that produces formatted text according to the tree structure.
A shim for the gradient transformation library optax that supports:
easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
gradient transformation objects that can be passed dynamically to jitted functions, and
generic type annotations.
A pretty printer, print_generic, for aggregate and vector types, including dataclasses. (See display.) It features:
support for traced values,
colorized tree output for aggregate structures, and
formatted tabular output for arrays (or statistics when there’s no room for tabular output).
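Marking a dataclass field as static means it is kept out of the pytree's leaves. JAX transformations see only the dynamic fields, while static values travel in the tree structure (the auxiliary data), so they are not traced, should be hashable, and changing one of them makes jit retrace. The sketch below illustrates that split using only the standard library. The names static_field, flatten, unflatten, and Layer are hypothetical and are not tjax's API. The flatten/unflatten pair follows the (children, aux_data) convention that jax.tree_util.register_pytree_node expects, which is the kind of registration a decorator like tjax's dataclass performs for you.

```python
from dataclasses import dataclass, field, fields
from typing import Any

Aux = tuple[type, tuple[tuple[str, Any], ...]]


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: tag a dataclass field as static via its metadata."""
    return field(metadata={"static": True}, **kwargs)


def _is_static(f: Any) -> bool:
    return bool(f.metadata.get("static", False))


def flatten(obj: Any) -> tuple[list[Any], Aux]:
    """Split a dataclass instance into dynamic leaves and static auxiliary data."""
    leaves = [getattr(obj, f.name) for f in fields(obj) if not _is_static(f)]
    static_items = tuple(
        (f.name, getattr(obj, f.name)) for f in fields(obj) if _is_static(f)
    )
    return leaves, (type(obj), static_items)


def unflatten(aux: Aux, leaves: list[Any]) -> Any:
    """Rebuild an instance from auxiliary data and (possibly transformed) leaves."""
    cls, static_items = aux
    dynamic_names = [f.name for f in fields(cls) if not _is_static(f)]
    kwargs = dict(zip(dynamic_names, leaves))
    kwargs.update(static_items)
    return cls(**kwargs)


@dataclass(frozen=True)
class Layer:
    weight: float
    bias: float
    activation: str = static_field(default="relu")  # static: never a leaf


layer = Layer(weight=2.0, bias=0.5, activation="tanh")
leaves, aux = flatten(layer)
# Only the dynamic fields are transformed, in the spirit of jax.tree_util.tree_map.
scaled = unflatten(aux, [x * 2 for x in leaves])
print(leaves, scaled)
```

Because activation lives in the auxiliary data, two Layer instances that differ only in activation have different tree structures, which is what lets JAX treat such a field as a compile-time constant.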
Minor components
Tjax also includes:
Versions of custom_vjp and custom_jvp that support being used on methods: custom_vjp_method and custom_jvp_method. (See shims.)
Tools for working with cotangents. (See cotangent_tools.)
JAX tree registration for NetworkX graph types. (See graph.)
Leaky integration, leaky_integrate, and Ornstein-Uhlenbeck process iteration, diffused_leaky_integrate. (See leaky_integral.)
An improved version of jax.tree_util.Partial. (See partial.)
A testing function, assert_tree_allclose, that automatically produces testing code, and a related function, tree_allclose. (See testing.)
Basic tools like divide_where. (See tools.)
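Leaky integration moves a value x toward a fixed point under exponential decay, dx/dt = drift - decay * x, and the Ornstein-Uhlenbeck process adds Gaussian diffusion to the same dynamics. Both have standard closed-form updates over a time step, sketched below in plain Python. The function names and parameters (leaky_integrate_sketch, ou_step_sketch, drift, decay, mean, volatility) are hypothetical: tjax's leaky_integrate and diffused_leaky_integrate operate on JAX arrays and may parameterize the update differently.

```python
import math
import random


def leaky_integrate_sketch(value: float, time_step: float, drift: float,
                           decay: float) -> float:
    """Exact solution of dx/dt = drift - decay * x after time_step (decay > 0)."""
    retention = math.exp(-decay * time_step)  # fraction of the old value that survives
    return value * retention + (drift / decay) * (1.0 - retention)


def ou_step_sketch(value: float, time_step: float, mean: float, decay: float,
                   volatility: float, rng: random.Random) -> float:
    """Exact step of dX = decay * (mean - X) dt + volatility dW (decay > 0)."""
    retention = math.exp(-decay * time_step)
    deterministic = value * retention + mean * (1.0 - retention)
    # Conditional variance of the OU process accumulated over time_step.
    variance = volatility ** 2 * (1.0 - retention ** 2) / (2.0 * decay)
    return deterministic + math.sqrt(variance) * rng.gauss(0.0, 1.0)


# The leaky integral approaches its fixed point drift / decay = 2.0.
print(leaky_integrate_sketch(0.0, time_step=10.0, drift=3.0, decay=1.5))
```

Setting volatility to 0 recovers the leaky update with drift = decay * mean, which is a quick consistency check between the two; over long time steps the OU samples settle to a variance of volatility ** 2 / (2 * decay).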
Contribution guidelines
The implementation should be consistent with the surrounding style, be type annotated, and pass the linters below.
To run tests: pytest
There are a few tools to clean and check the source:
ruff check
pyright
mypy
isort .
pylint tjax tests
File details
Details for the file tjax-1.4.1.tar.gz.
File metadata
- Download URL: tjax-1.4.1.tar.gz
- Upload date:
- Size: 149.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8cd5049cd8b650b902b799002868b8be8a888cb87770793ec8a1d5dd62b4616d |
| MD5 | 5e73e4aaac4a1ae8342d7fc407d88698 |
| BLAKE2b-256 | 0c3f66a1f78e4dec7eea86d2f958d1a2999ac4142acbbb0f9b1fc72fcb6d4bdb |
File details
Details for the file tjax-1.4.1-py3-none-any.whl.
File metadata
- Download URL: tjax-1.4.1-py3-none-any.whl
- Upload date:
- Size: 46.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.19
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ecf46293e3a236bee5111eab7dd380b9954800461dec1ba81d1145edf768fac2 |
| MD5 | 6a7b87b8b5880887ac95fa72b27d4ba2 |
| BLAKE2b-256 | e511ce9e681da06d0d8db95fcea04f812b37842b364e401183024fe20aa66287 |