Boax is a Bayesian Optimization library for JAX.

Boax: A Bayesian Optimization library for JAX.

Overview | Installation | Getting Started | Documentation

Boax is currently in early alpha and under active development!

Overview

Boax is a composable library of core components for Bayesian Optimization that is designed for flexibility. It comes with low-level interfaces for:

  • Fitting a surrogate model to data (boax.prediction):
    • Kernels
    • Mean Functions
    • Surrogate Models
    • Objectives
  • Constructing and optimizing acquisition functions (boax.optimization):
    • Acquisition Functions
    • Constraint Functions
    • Maximizers

Installation

You can install the latest released version of Boax from PyPI via:

pip install boax

or you can install the latest development version from GitHub:

pip install git+https://github.com/Lando-L/boax.git

Getting Started

Here is a quick-start example of the two main components that form the Bayesian optimization loop. For more details, check out the docs.

  1. Create a dataset from a noisy sinusoid.
from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import value_and_grad

import optax

from boax.prediction import kernels, means, models, objectives
from boax.optimization import acquisitions, maximizers

bounds = jnp.array([[-3.0, 3.0]])

def objective(x):
  return jnp.sin(4 * x[..., 0]) + jnp.cos(2 * x[..., 0])

sample_key, noise_key, maximizer_key = random.split(random.key(0), 3)
x_train = random.uniform(sample_key, minval=bounds[:, 0], maxval=bounds[:, 1], shape=(10, 1))
y_train = objective(x_train) + 0.3 * random.normal(noise_key, shape=(10,))
  2. Fit a Gaussian Process surrogate model to the training dataset.
def prior(amplitude, length_scale, noise):
  return models.gaussian_process(
    means.zero(),
    kernels.scaled(kernels.rbf(nn.softplus(length_scale)), nn.softplus(amplitude)),
    nn.softplus(noise),
  )

def target_log_prob(params):
  # Returns the negative marginal log likelihood, which serves as the loss to minimize.
  y_hat = prior(**params)(x_train)
  return -objectives.exact_marginal_log_likelihood()(y_hat, y_train)

params = {
  'amplitude': jnp.zeros(()),
  'length_scale': jnp.zeros(()),
  'noise': jnp.array(-5.),
}

optimizer = optax.adam(0.01)

def train_step(state, iteration):
  # state is the (params, opt_state) tuple carried through lax.scan.
  loss, grads = value_and_grad(target_log_prob)(state[0])
  updates, opt_state = optimizer.update(grads, state[1])
  params = optax.apply_updates(state[0], updates)

  return (params, opt_state), loss

# Run 500 Adam steps; history collects the loss at every step.
(next_params, next_opt_state), history = lax.scan(
  jit(train_step),
  (params, optimizer.init(params)),
  jnp.arange(500)
)
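
The hyperparameters in this step are stored as unconstrained values and passed through nn.softplus before use, which keeps the amplitude, length scale, and noise positive while the optimizer updates them freely. As a rough illustration of what the initial values correspond to, here is a dependency-free sketch in plain Python. The helper softplus below is a local stand-in written for this example, not Boax or JAX code.

```python
import math

def softplus(x):
    """Numerically stable softplus: log(1 + exp(x)), always positive."""
    # max(x, 0) + log1p(exp(-|x|)) avoids overflow for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

# Unconstrained initial values, copied from the params dictionary in this step.
raw_params = {'amplitude': 0.0, 'length_scale': 0.0, 'noise': -5.0}

# Positive values that result from applying softplus to each entry.
constrained = {name: softplus(value) for name, value in raw_params.items()}

for name, value in constrained.items():
    print(f'{name}: {value:.4f}')
```

Since softplus(0) equals ln 2 (about 0.693) and softplus(-5) is about 0.0067, the optimization starts from moderate amplitude and length-scale values and a small noise term.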
  3. Construct and optimize a UCB acquisition function.
def posterior(amplitude, length_scale, noise):
  return models.gaussian_process_regression(
    x_train,
    y_train,
    means.zero(),
    kernels.scaled(kernels.rbf(nn.softplus(length_scale)), nn.softplus(amplitude)),
    nn.softplus(noise),
  )

surrogate = posterior(**next_params)
acqf = acquisitions.upper_confidence_bound(surrogate, beta=2.0)
maximizer = maximizers.bfgs(acqf, bounds, q=1, num_restarts=25, num_raw_samples=500)

candidates, values = maximizer(maximizer_key)
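
For intuition about what the maximizer is optimizing, the sketch below scores a few points with an upper confidence bound in plain Python. It assumes the common formulation mean + sqrt(beta) * standard deviation; whether Boax scales the standard deviation by beta or by its square root is not stated here, so treat the formula, the function name ucb_score, and the made-up posterior values as illustrative assumptions rather than the library's behavior.

```python
import math

def ucb_score(mean, std, beta=2.0):
    """Upper confidence bound, assuming the form mean + sqrt(beta) * std."""
    return mean + math.sqrt(beta) * std

# Hypothetical posterior summaries at three input points (not Boax output).
points = [-1.0, 0.5, 2.0]
posterior_means = [0.2, 0.9, 0.4]
posterior_stds = [0.1, 0.05, 0.8]

scores = [ucb_score(m, s) for m, s in zip(posterior_means, posterior_stds)]

# The point with the highest score is the one to evaluate next.
best_index = max(range(len(scores)), key=scores.__getitem__)
best_point = points[best_index]
print(best_point, scores[best_index])
```

In this toy setup the third point wins despite its lower mean, because its larger uncertainty raises the bound. Increasing beta shifts the balance further toward exploring uncertain regions.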

Citing Boax

To cite Boax, please use the following citation:

@software{boax2023github,
  author = {Lando L{\"o}per},
  title = {{B}oax: A Bayesian Optimization library for {JAX}},
  url = {https://github.com/Lando-L/boax},
  version = {0.0.3},
  year = {2023},
}

In the above BibTeX entry, the version number is intended to be that from boax/version.py, and the year corresponds to the project's open-source release.
