An easy-to-use API for closed-form continuous-time models in TensorFlow and PyTorch.
Closed-form Continuous-time Models
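Closed-form Continuous-time Neural Networks (CfCs) are sequential neural information processing units that approximate the dynamics of liquid time-constant (LTC) networks in closed form, so they do not need a numerical ODE solver during training or inference.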
Paper (open access): https://www.nature.com/articles/s42256-022-00556-7
arXiv: https://arxiv.org/abs/2106.13898
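The paper writes the CfC hidden state in closed form as x(t) = σ(-f(x, I) t) ⊙ g(x, I) + [1 - σ(-f(x, I) t)] ⊙ h(x, I), where f, g and h are learned functions of the state x and input I, and σ is the logistic sigmoid. The NumPy sketch below evaluates one such update to illustrate the formula. It assumes each of f, g and h is a single hypothetical dense layer on the concatenated state and input (with tanh on g and h); it is not the implementation used by this package.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    """Logistic sigmoid."""
    return 1.0 / (1.0 + np.exp(-z))


def cfc_update(x, inp, t, params):
    """One closed-form CfC state update (illustrative sketch).

    x: hidden state, shape (n,)
    inp: input I, shape (m,)
    t: elapsed time (scalar)
    params: hypothetical weights for single-layer heads f, g, h
    """
    z = np.concatenate([x, inp])  # every head sees state and input
    f = z @ params["W_f"] + params["b_f"]  # time-constant head
    g = np.tanh(z @ params["W_g"] + params["b_g"])
    h = np.tanh(z @ params["W_h"] + params["b_h"])
    gate = sigmoid(-f * t)  # time-dependent gate sigma(-f t)
    return gate * g + (1.0 - gate) * h


rng = np.random.default_rng(0)
n, m = 3, 2  # hidden size, input size
params = {}
for name in ("f", "g", "h"):
    params[f"W_{name}"] = rng.normal(size=(n + m, n))
    params[f"b_{name}"] = np.zeros(n)

x = np.zeros(n)
inp = np.array([1.0, -1.0])
x_next = cfc_update(x, inp, t=1.0, params=params)
print(x_next.shape)  # (3,)
```

Because the gate lies in (0, 1), each new state is a per-unit blend of g and h, and the elapsed time t only enters through the gate.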
Requirements
- Python 3.6 or newer
- TensorFlow 2.4 or newer
- PyTorch
- NumPy
To create a fresh Anaconda environment with the required dependencies:
conda env create --file environment.yml
conda activate cfc
Usage
Example
import numpy as np
from cfc_model.dense_model import SequentialModel
X = np.array([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 1, 0, 0],
[1, 0, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 0, 1, 0]])
y = np.array([0, 0, 1, 1, 1, 0, 1, 1])
model = SequentialModel()
model.fit(X, y)
y_pred = model.predict([1, 1, 0, 1]) # y_pred equals 0
The following configuration flags can be passed to fit() in the config dictionary:
- no_gate: runs the CfC without the (1-sigmoid) part
- minimal: runs the CfC direct solution
- use_ltc: runs an LTC with a semi-implicit ODE solver instead of a CfC
- use_mixed: mixes the CfC's RNN state with an LSTM to avoid vanishing gradients
If none of these flags are provided, the full CfC model is used.
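To show what no_gate changes, here is a NumPy sketch of the step that blends two tanh heads, ff1 and ff2, with a time-dependent sigmoid gate t_interp = σ(t_a · t + t_b). It is modeled on the public CfC reference code, which is an assumption; this package's internals may differ, and all names here are hypothetical. The full CfC computes ff1 · (1 - t_interp) + t_interp · ff2, while no_gate drops the (1 - t_interp) factor on ff1.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + np.exp(-z))


def cfc_interp(ff1, ff2, t_a, t_b, ts, no_gate=False):
    """Blend the two CfC heads with a time-dependent gate (sketch).

    ff1, ff2: outputs of two tanh heads
    t_a, t_b: outputs of the two time heads
    ts: elapsed time between observations
    """
    t_interp = sigmoid(t_a * ts + t_b)
    if no_gate:
        # no_gate variant: omit the (1 - sigmoid) factor on ff1
        return ff1 + t_interp * ff2
    # full CfC: convex blend of the two heads
    return ff1 * (1.0 - t_interp) + t_interp * ff2


ff1 = np.tanh(np.array([0.5, -0.2]))
ff2 = np.tanh(np.array([-1.0, 0.8]))
t_a, t_b = np.array([1.0, 1.0]), np.array([0.0, 0.0])
print(cfc_interp(ff1, ff2, t_a, t_b, ts=0.0))
print(cfc_interp(ff1, ff2, t_a, t_b, ts=0.0, no_gate=True))
```

With the gate, the output always stays between ff1 and ff2; without it, ff1 passes through unscaled and ff2 is added on top.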
Example
import numpy as np
from cfc_model.dense_model import SequentialModel
X = np.array([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 1, 0, 0],
[1, 0, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 0, 1, 0]])
y = np.array([0, 0, 1, 1, 1, 0, 1, 1])
model = SequentialModel()
# Runs an LTC with a semi-implicit ODE solver instead of a CfC
config = {"use_ltc": True}
model.fit(X, y, config=config)
y_pred = model.predict([1, 1, 0, 1]) # y_pred equals 0
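Because the variant is selected through a plain dictionary, a misspelled key could go unnoticed. The sketch below is a hypothetical helper (make_config is not part of cfc_model) that only accepts the four flags listed above and returns a dictionary containing just the enabled ones, so an empty result corresponds to the full CfC model.

```python
KNOWN_FLAGS = ("no_gate", "minimal", "use_ltc", "use_mixed")


def make_config(**flags):
    """Build a config dict holding only the enabled CfC variant flags.

    Hypothetical helper: raises ValueError on unknown flag names so that
    a typo does not silently fall back to the full CfC model.
    """
    unknown = set(flags) - set(KNOWN_FLAGS)
    if unknown:
        raise ValueError(f"unknown config flags: {sorted(unknown)}")
    # keep only the flags that are explicitly switched on
    return {name: True for name, value in flags.items() if value}


config = make_config(use_ltc=True)
print(config)  # {'use_ltc': True}
```

The result can then be passed as model.fit(X, y, config=config), as in the example above.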
Cite
title = {Closed-form continuous-time neural networks},
journal = {Nature Machine Intelligence},
author = {Hasani, Ramin and Lechner, Mathias and Amini, Alexander and Liebenwein, Lucas and Ray, Aaron and Tschaikowski, Max and Teschl, Gerald and Rus, Daniela},
issn = {2522-5839},
month = nov,
year = {2022},
}
Download files
Source Distribution
cfc_model-1.0.5.tar.gz (23.0 MB)
Built Distribution
cfc_model-1.0.5-py3-none-any.whl (23.0 MB)
Hashes for cfc_model-1.0.5-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 1e3ae4c52bbf78d34945186cdf524ed35ae4e295d2a381bc1467e0808ad1fadb
MD5 | 24b6e84447681a0c06040bc126c50907
BLAKE2b-256 | 8e2d31d03ea9ac73515cc9a810b9a91afaaec02c180a110f82cfc8e9eb71baba
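A downloaded wheel can be checked against the SHA256 digest above with Python's standard hashlib module. This is a minimal sketch; the local file name is an assumption about where the wheel was saved.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for cfc_model-1.0.5-py3-none-any.whl
EXPECTED_SHA256 = "1e3ae4c52bbf78d34945186cdf524ed35ae4e295d2a381bc1467e0808ad1fadb"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    # assumed download location; adjust to wherever the wheel was saved
    wheel = Path("cfc_model-1.0.5-py3-none-any.whl")
    if wheel.exists():
        ok = sha256_of(wheel) == EXPECTED_SHA256
        print("hash OK" if ok else "hash MISMATCH")
    else:
        print(f"{wheel} not found")
```

Reading in chunks keeps memory use flat, which matters little for a 23 MB file but keeps the helper reusable for larger downloads.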