
Unofficial implementation for “Riemannian Adaptive Optimization Methods” ICLR2019 and more

Project description


Manifold-aware pytorch.optim.


What is done so far

Work is in progress, but you can already use it. Note that the API might change in future releases.

Tensors

  • geoopt.ManifoldTensor – just like torch.Tensor, with an additional manifold keyword argument.

  • geoopt.ManifoldParameter – same as above, but subclassed so that torch.nn.Module.parameters() correctly recognizes it as a parameter.

All of the containers above have special methods for working with them as points on a given manifold:

  • .proj_() – in-place projection onto the manifold.

  • .proju(u) – project vector u onto the tangent space at this point. All vectors passed to the methods below must be projected this way first.

  • .inner(u, v=None) – inner product of two tangent vectors u and v at this point. The vectors are not projected here; they are assumed to be projected already.

  • .retr(u, t) – retraction map following vector u for time t.

  • .transp(u, t, v, *more) – transport vector v (and possibly more vectors) along direction u for time t.

  • .retr_transp(u, t, v, *more) – retract self and transport vector v (and possibly more vectors) along direction u for time t in a single call; the results are plain tensors.
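To make the method list concrete, here is a minimal plain-Python sketch for the simplest Stiefel case, p = 1, where A^T A = I reduces to the unit sphere. The free functions reuse the README's method names for readability, but they are hypothetical helpers and not geoopt's API: they take the point x explicitly, retract by renormalizing x + t*u, and transport by re-projecting onto the tangent space at the new point. geoopt may use different retraction and transport rules. Treating v=None in inner as <u, u> is also an assumption.

```python
import math

# Hypothetical free-function analogues of the README's point methods, for the
# unit sphere {x in R^n : x^T x = 1} (the Stiefel manifold with p = 1).
# Each function takes the point x explicitly instead of being a tensor method.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def proj(x):
    """Out-of-place analogue of .proj_(): rescale x onto the sphere."""
    norm = math.sqrt(dot(x, x))
    return [xi / norm for xi in x]

def proju(x, u):
    """Analogue of .proju(u): remove the component of u along x."""
    c = dot(x, u)
    return [ui - c * xi for ui, xi in zip(u, x)]

def inner(x, u, v=None):
    """Analogue of .inner(u, v) under the embedded metric (a plain dot product).
    Assumption: v=None means <u, u>. x is unused because this metric does not
    depend on the point."""
    return dot(u, u if v is None else v)

def retr(x, u, t):
    """Analogue of .retr(u, t): move to x + t*u, then renormalize."""
    return proj([xi + t * ui for xi, ui in zip(x, u)])

def transp(x, u, t, v, *more):
    """Analogue of .transp(u, t, v, *more): project each vector onto the
    tangent space at the retracted point. Always returns a tuple."""
    y = retr(x, u, t)
    return tuple(proju(y, w) for w in (v, *more))

def retr_transp(x, u, t, v, *more):
    """Analogue of .retr_transp(...): returns (new_point, *transported)."""
    y = retr(x, u, t)
    return (y, *(proju(y, w) for w in (v, *more)))

x = proj([3.0, 4.0])          # [0.6, 0.8]
u = proju(x, [1.0, 0.0])      # a tangent vector at x
y, w = retr_transp(x, u, 0.5, u)
```

The order of calls mirrors the README's rule: ambient vectors go through proju before they are passed to inner, retr, or transp. retr_transp simply avoids computing the retracted point twice.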

Manifolds

  • geoopt.Euclidean – unconstrained manifold R^n with the Euclidean metric

  • geoopt.Stiefel – Stiefel manifold of matrices A in R^{n x p} such that A^T A = I, with n >= p
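For general p, the two basic Stiefel operations can be sketched in plain Python, assuming the embedded (Euclidean) metric. This sketch maps a full-rank matrix onto the manifold with modified Gram-Schmidt, which gives the Q factor of a thin QR decomposition. geoopt's own proj_ may use a different method, for example an SVD-based polar projection. The tangent projection U - A sym(A^T U) is the standard one for this metric: it makes A^T U skew-symmetric, which is the tangent-space condition. Matrices are lists of rows, and all function names are hypothetical.

```python
import math

# Plain-Python matrices are lists of rows. All helper names are hypothetical.

def transpose(M):
    return [list(col) for col in zip(*M)]

def matmul(A, B):
    Bt = transpose(B)
    return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]

def stiefel_proj(A):
    """Map a full-column-rank n x p matrix (n >= p) onto the Stiefel manifold
    with modified Gram-Schmidt (the Q factor of a thin QR decomposition)."""
    n, p = len(A), len(A[0])
    if n < p:
        raise ValueError("Stiefel manifold requires n >= p")
    basis = []  # orthonormal columns built so far
    for col in transpose(A):
        w = list(col)
        for q in basis:
            c = sum(qi * wi for qi, wi in zip(q, w))
            w = [wi - c * qi for wi, qi in zip(w, q)]
        norm = math.sqrt(sum(wi * wi for wi in w))
        if norm < 1e-12:
            raise ValueError("columns are linearly dependent")
        basis.append([wi / norm for wi in w])
    return transpose(basis)  # back to n rows of p entries

def stiefel_proju(A, U):
    """Project U onto the tangent space at A: U - A sym(A^T U)."""
    M = matmul(transpose(A), U)  # p x p
    p = len(M)
    sym = [[(M[i][j] + M[j][i]) / 2 for j in range(p)] for i in range(p)]
    AS = matmul(A, sym)          # n x p
    return [[u - a for u, a in zip(urow, arow)] for urow, arow in zip(U, AS)]

A = stiefel_proj([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
U = stiefel_proju(A, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
```

Here A is a 3 x 2 matrix with orthonormal columns, so A^T A = I. U satisfies A^T U + U^T A = 0, which means it is a valid tangent vector at A.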

Optimizers

  • geoopt.optim.RiemannianSGD – a subclass of torch.optim.SGD with the same API

  • geoopt.optim.RiemannianAdam – a subclass of torch.optim.Adam
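The optimizers are described only as subclasses of torch.optim.SGD and torch.optim.Adam, so they are presumably constructed the same way, from a parameter iterable and a learning rate. The update behind a Riemannian SGD step can be sketched in plain Python on the unit sphere (Stiefel with p = 1). Each step projects the Euclidean gradient onto the tangent space and then retracts along its negative. This is an illustration under stated assumptions, not geoopt's implementation. The normalization retraction and the toy objective f(x) = -<c, x>, whose minimizer on the sphere is c / |c|, were chosen for the example, and all helper names are hypothetical.

```python
import math

# Riemannian SGD sketched on the unit sphere (Stiefel with p = 1).
# Assumptions for illustration: embedded metric, normalization retraction,
# toy objective f(x) = -<c, x> whose minimizer on the sphere is c / |c|.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def proj(x):
    norm = math.sqrt(dot(x, x))
    return [xi / norm for xi in x]

def proju(x, u):
    c = dot(x, u)
    return [ui - c * xi for ui, xi in zip(u, x)]

def retr(x, u, t):
    return proj([xi + t * ui for xi, ui in zip(x, u)])

def rsgd(egrad_fn, x0, lr=0.1, steps=200):
    """Repeat: Riemannian gradient = proju(Euclidean gradient); then follow
    its negative for time lr using the retraction."""
    x = proj(x0)
    for _ in range(steps):
        rgrad = proju(x, egrad_fn(x))
        x = retr(x, [-g for g in rgrad], lr)
    return x

c = [1.0, 2.0, 2.0]
x_min = rsgd(lambda x: [-ci for ci in c], [1.0, 0.0, 0.0])
# x_min ends up close to c / |c| = [1/3, 2/3, 2/3]
```

An adaptive variant such as RiemannianAdam also keeps moment estimates. In the Riemannian setting the momentum presumably has to be moved to the new tangent space after every step, and that combined step is exactly what retr_transp provides.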

Samplers

  • geoopt.samplers.RSGLD – Riemannian Stochastic Gradient Langevin Dynamics

  • geoopt.samplers.RHMC – Riemannian Hamiltonian Monte Carlo

  • geoopt.samplers.SGRHMC – Stochastic Gradient Riemannian Hamiltonian Monte Carlo
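As a rough illustration of what a Riemannian Langevin sampler does, the plain-Python sketch below runs a simplified RSGLD-style chain on the unit sphere (Stiefel with p = 1). Each step takes a Euclidean Langevin move (drift -eps/2 * grad U plus Gaussian noise with variance eps), projects it onto the tangent space, and retracts by renormalizing. Several choices are assumptions made for the example: the target (a von Mises-Fisher-like density proportional to exp(kappa <mu, x>)), the step size, the burn-in length, and the function names. The sketch uses full rather than stochastic gradients, omits any correction terms a careful sampler may need, and is not geoopt's RSGLD. RHMC and SGRHMC are not covered.

```python
import math
import random

# Simplified Langevin chain on the unit sphere (assumed example, not geoopt's
# RSGLD). Target density is proportional to exp(-U(x)) on the sphere.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def proj(x):
    norm = math.sqrt(dot(x, x))
    return [xi / norm for xi in x]

def proju(x, u):
    c = dot(x, u)
    return [ui - c * xi for ui, xi in zip(u, x)]

def rsgld(egrad_U, x0, eps=0.01, steps=5000, burn_in=1000, seed=0):
    rng = random.Random(seed)
    x = proj(x0)
    samples = []
    for k in range(steps):
        g = egrad_U(x)
        noise = [rng.gauss(0.0, 1.0) for _ in x]
        # Euclidean Langevin move: drift -eps/2 * grad U plus N(0, eps I) noise
        step = [-0.5 * eps * gi + math.sqrt(eps) * ni for gi, ni in zip(g, noise)]
        # Project onto the tangent space at x, then retract by renormalizing.
        # The tangent step is orthogonal to x, so the norm here is >= 1.
        x = proj([xi + si for xi, si in zip(x, proju(x, step))])
        if k >= burn_in:
            samples.append(x)
    return samples

mu, kappa = [0.0, 0.0, 1.0], 10.0
# U(x) = -kappa * <mu, x>, so its Euclidean gradient is -kappa * mu
samples = rsgld(lambda x: [-kappa * m for m in mu], [1.0, 0.0, 0.0])
mean_cos = sum(dot(mu, s) for s in samples) / len(samples)
```

Every sample stays on the sphere. With kappa = 10 the samples should concentrate around mu, so mean_cos is expected to be well above zero.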



Download files


Source Distribution

geoopt-0.0.1rc2.tar.gz (12.8 kB)

Uploaded Source

Built Distribution


geoopt-0.0.1rc2-py3-none-any.whl (20.5 kB)

Uploaded Python 3

File details

Details for the file geoopt-0.0.1rc2.tar.gz.

File metadata

  • Download URL: geoopt-0.0.1rc2.tar.gz
  • Upload date:
  • Size: 12.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/39.1.0 requests-toolbelt/0.8.0 tqdm/4.26.0 CPython/3.6.3

File hashes

Hashes for geoopt-0.0.1rc2.tar.gz
Algorithm Hash digest
SHA256 13b386e7eaf4e0e7710045f78fd51395f12e1803b30a56f23bdf2820c06e519e
MD5 357f3bd42220f6ee74216d4d4ac40e24
BLAKE2b-256 9ed144ae978a14944d77167ec606c488ad85a174c1850ac4f460b70d60043d6d
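You can check the digests above locally before installing. The following is a minimal sketch using Python's standard hashlib module. The function names are illustrative, and the file is assumed to have been downloaded already.

```python
import hashlib

def sha256_of_file(path, chunk_size=65536):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify_sha256(path, expected_hex):
    """Return True if the file's SHA256 digest matches expected_hex
    (case-insensitive, surrounding whitespace ignored)."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For an intact download, verify_sha256("geoopt-0.0.1rc2.tar.gz", "13b386e7eaf4e0e7710045f78fd51395f12e1803b30a56f23bdf2820c06e519e") should return True. pip can also enforce the check itself in hash-checking mode. Add a requirements line such as geoopt==0.0.1rc2 --hash=sha256:<digest>, repeating --hash for each distribution file pip might select, then run pip install --require-hashes -r requirements.txt.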


File details

Details for the file geoopt-0.0.1rc2-py3-none-any.whl.

File metadata

  • Download URL: geoopt-0.0.1rc2-py3-none-any.whl
  • Upload date:
  • Size: 20.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.20.1 setuptools/39.1.0 requests-toolbelt/0.8.0 tqdm/4.26.0 CPython/3.6.3

File hashes

Hashes for geoopt-0.0.1rc2-py3-none-any.whl
Algorithm Hash digest
SHA256 9217829f1d6eff6d964c476632173a0998993e6e012062f50531404752c27043
MD5 46fd99ba5a9a9054939fec90106c0016
BLAKE2b-256 0fd9f25529555f986d04e986141c0f1ac033811f6e340625b8de8c8bfaa4ead8

