fax: fixed-point jax
Implicit and competitive differentiation in JAX.
Our "competitive differentiation" approach uses Competitive Gradient Descent (CGD) to solve the equality-constrained nonlinear program associated with the fixed-point problem. A standalone implementation of CGD is provided in fax/competitive/cga.py, and the equality-constrained solver built on it is available as fax.constrained.cga_lagrange_min or fax.constrained.cga_ecp. An implementation of implicit differentiation based on Christianson's two-phase reverse accumulation algorithm is available through the function fax.implicit.two_phase_solver.
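The two-phase idea can be sketched in plain JAX with jax.custom_vjp. The forward phase iterates x ← f(θ, x) to a fixed point x*. The reverse phase then iterates the adjoint fixed point w ← x̄ + (∂f/∂x)ᵀ w at x* and returns θ̄ = (∂f/∂θ)ᵀ w. This is a minimal sketch closely modeled on the fixed-point example in the JAX custom-derivatives tutorial, not the code behind fax.implicit.two_phase_solver. The helper names fixed_point_iterate and fixed_point, the tolerance, the iteration cap and the toy map f(a, x) = a cos x are assumptions. It also assumes x is a single array and f is a contraction in x that preserves its shape and dtype.

```python
from functools import partial

import jax
import jax.numpy as jnp


def fixed_point_iterate(f, params, x0, tol=1e-6, max_iter=500):
    """Iterate x <- f(params, x) until successive iterates differ by < tol."""
    def cond(state):
        i, x, x_prev = state
        return jnp.logical_and(i < max_iter, jnp.max(jnp.abs(x - x_prev)) > tol)

    def body(state):
        i, x, _ = state
        return i + 1, f(params, x), x

    _, x_star, _ = jax.lax.while_loop(cond, body, (0, f(params, x0), x0))
    return x_star


# Hypothetical wrapper, not fax's API: f is static, params is differentiated.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, params, x0):
    return fixed_point_iterate(f, params, x0)


def fixed_point_fwd(f, params, x0):
    # Phase 1: solve the forward fixed point and keep only (params, x*).
    x_star = fixed_point_iterate(f, params, x0)
    return x_star, (params, x_star)


def fixed_point_bwd(f, res, x_bar):
    params, x_star = res
    # VJP of f with respect to x, linearized at the solution x*.
    _, vjp_x = jax.vjp(lambda x: f(params, x), x_star)
    # Phase 2: solve the adjoint fixed point w = x_bar + (df/dx)^T w.
    w = fixed_point_iterate(lambda xb, u: xb + vjp_x(u)[0], x_bar, x_bar)
    # Push the adjoint through df/dparams to get the parameter cotangent.
    _, vjp_p = jax.vjp(lambda p: f(p, x_star), params)
    (params_bar,) = vjp_p(w)
    # x* does not depend on the initial guess, so its cotangent is zero.
    return params_bar, jnp.zeros_like(x_star)


fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)


# Toy example: x* = a * cos(x*), a contraction in x for |a| < 1.
def f(a, x):
    return a * jnp.cos(x)


a = jnp.array(0.5)
x0 = jnp.array(0.0)
x_star = fixed_point(f, a, x0)
grad_a = jax.grad(lambda a: fixed_point(f, a, x0))(a)
```

Differentiating x* = a cos x* by hand gives dx*/da = cos x* / (1 + a sin x*), which is what grad_a should match. Because the reverse phase only needs x* and not the forward iterates, memory use in this sketch does not grow with the number of forward iterations.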
See fax/constrained/constrained_test.py for examples. Please note that the API is subject to change.
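To make the competitive approach more concrete, the sketch below implements one CGD step for a zero-sum game min_x max_y f(x, y), using the zero-sum update rule from Schäfer and Anandkumar's Competitive Gradient Descent paper. It applies that step to the Lagrangian L(θ, x, λ) = ℓ(x) + λᵀ(T(θ, x) − x) of a toy fixed-point-constrained program. This is not the code in fax/competitive/cga.py or fax.constrained. The names cgd_step, lagrangian, T, objective and TARGET, along with the step size and iteration count, are hypothetical. The dense mixed Hessian and jnp.linalg.solve only make sense for tiny problems.

```python
import jax
import jax.numpy as jnp


def cgd_step(f, x, y, eta):
    """One competitive gradient descent step for min_x max_y f(x, y)."""
    gx = jax.grad(f, argnums=0)(x, y)
    gy = jax.grad(f, argnums=1)(x, y)
    # Mixed second derivative D_xy f, shape (dim_x, dim_y); D_yx f is its transpose.
    d_xy = jax.jacfwd(jax.grad(f, argnums=0), argnums=1)(x, y)
    d_yx = d_xy.T
    eye_x = jnp.eye(x.shape[0], dtype=x.dtype)
    eye_y = jnp.eye(y.shape[0], dtype=y.dtype)
    # Each player's step accounts for the other player's anticipated response.
    dx = -eta * jnp.linalg.solve(eye_x + eta**2 * d_xy @ d_yx, gx + eta * d_xy @ gy)
    dy = eta * jnp.linalg.solve(eye_y + eta**2 * d_yx @ d_xy, gy - eta * d_yx @ gx)
    return x + dx, y + dy


# Hypothetical toy problem: choose theta so that the fixed point of
# T(theta, x) = 0.5 * x + theta (namely x = 2 * theta) matches TARGET.
TARGET = 1.0


def T(theta, x):
    return 0.5 * x + theta


def objective(x):
    return 0.5 * (x - TARGET) ** 2


def lagrangian(z, lam):
    # Primal z = (theta, x); lam multiplies the constraint T(theta, x) - x = 0.
    theta, x = z[0], z[1]
    residual = jnp.atleast_1d(T(theta, x) - x)
    return objective(x) + jnp.dot(lam, residual)


def solve(z0, lam0, eta=0.5, num_steps=200):
    body = lambda _, state: cgd_step(lagrangian, state[0], state[1], eta)
    return jax.lax.fori_loop(0, num_steps, body, (z0, lam0))


z, lam = solve(jnp.zeros(2), jnp.zeros(1))
theta, x = z
```

The minimizing player (θ, x) and the maximizing multiplier λ should meet at the constrained optimum θ = TARGET / 2, x = TARGET, λ = 0. For this toy problem, plain simultaneous gradient descent-ascent is only stable for a much smaller step size. The mixed-Hessian terms in the CGD update add the damping that lets η = 0.5 converge.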
References
Citing competitive differentiation:
@inproceedings{bacon2019optrl,
author={Pierre-Luc Bacon and Florian Schaefer and Clement Gehring and Animashree Anandkumar and Emma Brunskill},
title={A Lagrangian Method for Inverse Problems in Reinforcement Learning},
booktitle={NeurIPS Optimization Foundations for Reinforcement Learning Workshop},
year={2019},
url={http://lis.csail.mit.edu/pubs/bacon-optrl-2019.pdf},
keywords={Optimization, Reinforcement Learning, Lagrangian}
}
Citing this repo:
@misc{gehring2019fax,
author = {Clement Gehring and Pierre-Luc Bacon and Florian Schaefer},
title = {{FAX: differentiating fixed point problems in JAX}},
note = {Available at: https://github.com/gehring/fax},
year = {2019}
}