Dataclasses + JAX
jax_dataclasses
Library for using dataclasses as JAX PyTrees.
Key features:
- PyTree registration; automatic generation of flatten/unflatten ops.
- Static analysis-friendly. Works out of the box with tools like mypy and jedi.
- Support for serialization via flax.serialization.
Usage
Basic
jax_dataclasses is meant to be a drop-in replacement for dataclasses.dataclass:
- jax_dataclasses.dataclass has the same interface as dataclasses.dataclass, but also registers the class as a PyTree.
- jax_dataclasses.static_field has the same interface as dataclasses.field, but also marks the field as static. In a PyTree node, static fields are treated as part of the treedef instead of as a child of the node.
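To illustrate what it means for a field to live in the treedef rather than among the children, here is a minimal sketch that registers a dataclass by hand using plain jax.tree_util. This is not this library's implementation; the Linear class, its fields, and the two helper functions are hypothetical names chosen for the example.

```python
import dataclasses

import jax
from jax import numpy as jnp


@dataclasses.dataclass(frozen=True)
class Linear:
    weight: jnp.ndarray  # Intended as a child (leaf) of the PyTree node.
    name: str  # Intended as static metadata stored in the treedef.


def _flatten(obj):
    # Return (children, auxiliary data). The auxiliary data becomes part of
    # the treedef and is never traced or mapped over.
    return (obj.weight,), obj.name


def _unflatten(aux, children):
    # Rebuild the node from the static auxiliary data and the children.
    return Linear(weight=children[0], name=aux)


jax.tree_util.register_pytree_node(Linear, _flatten, _unflatten)

layer = Linear(weight=jnp.ones(3), name="encoder")

# Only `weight` shows up as a leaf; `name` is carried by the treedef.
leaves, treedef = jax.tree_util.tree_flatten(layer)

# Mapping over the tree transforms the array and leaves `name` untouched.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, layer)
```

A static field, as described above, plays the role that name plays here: it is carried along with the structure and is not visited by tree operations.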
We also provide several aliases: jax_dataclasses.[field, asdict, astuple, is_dataclass, replace] are all identical to their counterparts in the standard dataclasses library.
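Since these aliases are described as identical to the standard library functions, their behavior can be sketched with the dataclasses module alone. The Config class and its fields below are hypothetical and exist only for illustration.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    learning_rate: float
    # `field` customizes a single attribute, here by giving it a default.
    layers: tuple = dataclasses.field(default=(64, 64))


config = Config(learning_rate=1e-3)

# Convert the instance to a dict or a tuple of its field values.
as_dict = dataclasses.asdict(config)
as_tuple = dataclasses.astuple(config)

# Check whether an object is a dataclass (instance or class).
is_dc = dataclasses.is_dataclass(config)

# Build a new instance with one field changed; `config` itself is unchanged.
updated = dataclasses.replace(config, learning_rate=1e-2)
```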
Mutations
All dataclasses are automatically marked as frozen and thus immutable. We do, however, provide an interface that will (a) make a copy of a PyTree and (b) return a context in which any of that copy's contained dataclasses are temporarily mutable:
from jax import numpy as jnp
import jax_dataclasses

@jax_dataclasses.dataclass
class Node:
    child: jnp.ndarray

obj = Node(child=jnp.zeros(3))

with jax_dataclasses.copy_and_mutate(obj) as obj_updated:
    # Make mutations to the dataclass.
    # Also does input validation: if the treedef of `obj` and `obj_updated` don't
    # match, an AssertionError will be raised.
    obj_updated.child = jnp.ones(3)

print(obj)
print(obj_updated)
Motivation
For compatibility with function transformations in JAX (jit, grad, vmap, etc), arguments and return values must all be PyTree containers. Dataclasses, by default, are not.
A few great solutions exist for automatically integrating dataclass-style objects into PyTree structures, notably chex.dataclass and flax.struct. This library implements another one.
Why not use chex.dataclass?
chex.dataclass is handy and lightweight, but currently lacks support for:
- Static fields: parameters that are either non-differentiable or simply not arrays.
- Serialization using flax.serialization. This is really handy when parameters need to be saved to disk!
Why not use flax.struct?
flax.struct addresses the two points above, but both it and chex.dataclass:
- Lack support for static analysis and type-checking. Static analysis for libraries like dataclasses and attrs tends to rely on tooling-specific custom plugins, which don't exist for either chex.dataclass or flax.struct.
- Make modifying deeply nested dataclasses fairly frustrating. Both introduce a .replace(self, ...) method to dataclasses that's a bit more convenient than the traditional dataclasses.replace(obj, ...) API, but this becomes really cumbersome to use when dataclasses are nested. Fixing this is the goal of jax_dataclasses.copy_and_mutate().
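The nesting problem can be seen with the standard dataclasses.replace(obj, ...) API alone. The sketch below uses three hypothetical classes (Optimizer, Trainer, Experiment) and changes a single value three levels down; every enclosing level has to be rebuilt explicitly.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Optimizer:
    learning_rate: float


@dataclasses.dataclass(frozen=True)
class Trainer:
    optimizer: Optimizer
    steps: int


@dataclasses.dataclass(frozen=True)
class Experiment:
    trainer: Trainer
    name: str


experiment = Experiment(
    trainer=Trainer(optimizer=Optimizer(learning_rate=1e-3), steps=100),
    name="baseline",
)

# Updating one nested field requires one replace() call per level of nesting.
updated = dataclasses.replace(
    experiment,
    trainer=dataclasses.replace(
        experiment.trainer,
        optimizer=dataclasses.replace(
            experiment.trainer.optimizer, learning_rate=1e-2
        ),
    ),
)
```

With a copy-and-mutate interface like the one shown in the Mutations section, the same change would be a single attribute assignment on the copy inside the with block.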