Equivariant convolutional neural networks for the group E(3) of 3-dimensional rotations, translations, and mirrors.
:construction: :construction: :construction: Disclaimer: This is a work in progress. No part of the library can be considered stable.
e3nn-jax
What is different from the pytorch version?
- no more `shared_weights` and `internal_weights` in `TensorProduct`; `jax.vmap` is used extensively instead (see the examples below)
- support for the python structure `IrrepsData`, which contains both a contiguous version of the data and a list of `jnp.array`. This avoids an unnecessary `jnp.concatenate` followed by indexing to undo the concatenation
- support for `None` in the list of `jnp.array` to avoid unnecessary computation with zeros
Example
Example with the `Irreps` class.
This class specifies a direct sum of irreducible representations.
It does not contain any actual data; it is used to specify the "type" of the data under rotation.
```python
from e3nn_jax import Irreps

irreps = Irreps("2x0e + 3x1e")  # 2 even scalars and 3 even vectors
irreps = irreps + irreps  # 2x0e+3x1e+2x0e+3x1e
irreps.D_from_angles(alpha=1.57, beta=1.57, gamma=0.0)  # 22x22 matrix
```
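As a sanity check on the 22x22 figure: the dimension of a direct sum is the sum of `mul * (2l + 1)` over its irreps. A standalone sketch of that arithmetic (plain Python, not the library's implementation):

```python
def irreps_dim(irreps: str) -> int:
    """Dimension of a direct sum like "2x0e + 3x1e": sum of mul * (2l + 1)."""
    total = 0
    for term in irreps.split("+"):
        mul, ir = term.strip().split("x")
        l = int(ir[:-1])  # drop the parity letter ("e" or "o")
        total += int(mul) * (2 * l + 1)
    return total

print(irreps_dim("2x0e + 3x1e"))  # 2*1 + 3*3 = 11
print(irreps_dim("2x0e + 3x1e + 2x0e + 3x1e"))  # 22, the size of the D matrix above
```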
It also includes the parity.
```python
irreps = Irreps("0e + 0o")  # an even scalar and an odd scalar
irreps.D_from_angles(alpha=0.0, beta=0.0, gamma=0.0, k=1)  # the matrix that applies parity
# DeviceArray([[ 1.,  0.],
#              [ 0., -1.]], dtype=float32)
```
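The parity matrix is diagonal: each even irrep ("e") contributes +1 entries and each odd irrep ("o") contributes -1 entries, repeated over its `mul * (2l + 1)` components. A standalone sketch of that rule (plain Python, not the library's implementation):

```python
def parity_diagonal(irreps: str) -> list:
    """Diagonal of the parity matrix: +1 per even component, -1 per odd one."""
    diag = []
    for term in irreps.split("+"):
        term = term.strip()
        if "x" in term:
            mul, ir = term.split("x")
            mul = int(mul)
        else:
            mul, ir = 1, term  # "0e" means "1x0e"
        l, p = int(ir[:-1]), (1.0 if ir[-1] == "e" else -1.0)
        diag += [p] * (mul * (2 * l + 1))
    return diag

print(parity_diagonal("0e + 0o"))  # [1.0, -1.0], the diagonal of the matrix above
```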
`IrrepsData` contains both the irreps and the data.
Here is an example: the tensor product of two vectors.
```python
import jax.numpy as jnp
from e3nn_jax import full_tensor_product, IrrepsData

out = full_tensor_product(
    IrrepsData.from_contiguous("1o", jnp.array([2.0, 0.0, 0.0])),
    IrrepsData.from_contiguous("1o", jnp.array([0.0, 2.0, 0.0]))
)
# out is of type `IrrepsData` and contains the following fields:

out.irreps
# 1x0e+1x1e+1x2e

out.contiguous
# DeviceArray([0.  , 0.  , 0.  , 2.83, 0.  , 2.83, 0.  , 0.  , 0.  ], dtype=float32)

out.list
# [DeviceArray([[0.]], dtype=float32),
#  DeviceArray([[0.  , 0.  , 2.83]], dtype=float32),
#  DeviceArray([[0.  , 2.83, 0.  , 0.  , 0.  ]], dtype=float32)]
```
The two fields `contiguous` and `list` contain the same information in different forms.
This is not a performance issue; we rely on `jax.jit` to optimize the code and eliminate the unused operations.
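As a numerical cross-check of the output above: the 0e channel of `1o ⊗ 1o` is proportional to the dot product (zero here), and the 1e channel matches the cross product divided by √2. The 1/√2 factor is inferred from the printed output, not taken from the library's documentation:

```python
import numpy as np

a = np.array([2.0, 0.0, 0.0])
b = np.array([0.0, 2.0, 0.0])

# 0e channel: proportional to the dot product (zero, as in the output above)
print(np.dot(a, b))  # 0.0

# 1e channel: cross product scaled by 1/sqrt(2), matching [0., 0., 2.83]
print(np.cross(a, b) / np.sqrt(2))
```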
Shared weights
torch version (e3nn repo):
```python
f = o3.FullyConnectedTensorProduct(irreps1, irreps2, irreps3, shared_weights=True)
f(x, y)
```
jax version (this repo):
```python
tp = FunctionalFullyConnectedTensorProduct(irreps1, irreps2, irreps3)
w = [jax.random.normal(key, i.path_shape) for i in tp.instructions if i.has_weight]
f = jax.vmap(tp.left_right, (None, 0, 0), 0)  # the weights are not batched
f = jax.jit(f)
f(w, x, y)
```
Batch weights
torch version:
```python
f = o3.FullyConnectedTensorProduct(irreps1, irreps2, irreps3, shared_weights=False)
f(x, y, w)
```
jax version:
```python
tp = FunctionalFullyConnectedTensorProduct(irreps1, irreps2, irreps3)
w = [jax.random.normal(key, (10,) + i.path_shape) for i in tp.instructions if i.has_weight]
f = jax.vmap(tp.left_right, (0, 0, 0), 0)  # the weights have a leading batch axis
f = jax.jit(f)
f(w, x, y)
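The only difference between the two patterns is the `in_axes` entry for the weights: `None` broadcasts a single set of weights over the batch, while `0` maps over a per-example leading axis. A toy `jax.vmap` sketch of that semantics, independent of e3nn:

```python
import jax
import jax.numpy as jnp

def apply(w, x, y):
    # toy stand-in for tp.left_right: a weighted inner product
    return w * jnp.dot(x, y)

x = jnp.ones((10, 3))  # batch of 10 inputs
y = jnp.ones((10, 3))

# shared weights: one scalar w, broadcast over the batch
shared = jax.vmap(apply, (None, 0, 0), 0)
print(shared(2.0, x, y))  # 2 * 3 = 6 for every example

# batched weights: one w per example
batched = jax.vmap(apply, (0, 0, 0), 0)
print(batched(jnp.arange(10.0), x, y))  # 3 * [0, 1, ..., 9]
```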
Extra channel index
torch version: not implemented
jax version just needs a few extra `jax.vmap` calls:
```python
def compose(f, g):
    return lambda *x: g(f(*x))

def tp_extra_channels(irreps_in1, irreps_in2, irreps_out):
    tp = FunctionalFullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)
    f = tp.left_right
    f = jax.vmap(f, (0, None, None), 0)  # channel_out
    f = jax.vmap(f, (0, None, 0), 0)  # channel_in2
    f = jax.vmap(f, (0, 0, None), 0)  # channel_in1
    f = compose(f, lambda z: jnp.sum(z, (0, 1)) / jnp.sqrt(z.shape[0] * z.shape[1]))
    tp.left_right = f
    return tp

tp = tp_extra_channels(irreps, irreps, irreps)
f = jax.vmap(tp.left_right, (None, 0, 0), 0)  # batch
f = jax.jit(f)

w = [jax.random.normal(k, (16, 32, 48) + i.path_shape) for i in tp.instructions if i.has_weight]
# w[i].shape = (ch_in1, ch_in2, ch_out) + path_shape

# x1.shape = (batch, ch_in1, irreps_in1.dim)
# x2.shape = (batch, ch_in2, irreps_in2.dim)
z = f(w, x1, x2)
# z.shape = (batch, ch_out, irreps_out.dim)
```
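The shape bookkeeping of the nested `jax.vmap` calls can be verified with a toy stand-in for `tp.left_right` (a scalar weight times an inner product, with a pretend output dimension of 5; none of this is e3nn code):

```python
import jax
import jax.numpy as jnp

def compose(f, g):
    return lambda *x: g(f(*x))

def toy_tp(w, x, y):
    # toy stand-in for tp.left_right: scalar weight times a dot product,
    # broadcast to a pretend irreps_out of dimension 5
    return w * jnp.dot(x, y) * jnp.ones(5)

f = toy_tp
f = jax.vmap(f, (0, None, None), 0)  # channel_out
f = jax.vmap(f, (0, None, 0), 0)     # channel_in2
f = jax.vmap(f, (0, 0, None), 0)     # channel_in1
f = compose(f, lambda z: jnp.sum(z, (0, 1)) / jnp.sqrt(z.shape[0] * z.shape[1]))
f = jax.vmap(f, (None, 0, 0), 0)     # batch

w = jnp.ones((16, 32, 48))   # (ch_in1, ch_in2, ch_out); no path_shape in this toy
x1 = jnp.ones((10, 16, 3))   # (batch, ch_in1, dim_in1)
x2 = jnp.ones((10, 32, 3))   # (batch, ch_in2, dim_in2)
print(f(w, x1, x2).shape)    # (10, 48, 5) = (batch, ch_out, dim_out)
```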