
An implementation of the BatchBALD algorithm

Reason this release was yanked:

Bug in SampledJointEntropy (uses the full tensor instead of the chunked one).


BatchBALD Redux

Clean reimplementation of "BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning"

For an introduction & more information, see http://batchbald.ml/. The paper can be found at http://arxiv.org/abs/1906.08158.

The original implementation used in the paper is available at https://github.com/BlackHC/BatchBALD.

We are grateful for fastai's nbdev, which powers this package.

Install

pip install batchbald_redux

Motivation

BatchBALD is an algorithm and acquisition function for Active Learning in a Bayesian setting using BNNs and MC dropout.

The acquisition function is the mutual information between the joint predictions for a candidate batch and the model parameters $\omega$:

$$a_{\text{BatchBALD}}((y_b)_B) = I[(y_b)_B;\omega]$$

The best candidate batch is one that maximizes this acquisition function.
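
Expanding the mutual information with the standard identity $I[X;Y] = H[X] - H[X \mid Y]$ shows where the joint entropy enters. The second line additionally assumes that the predictions are independent of each other given $\omega$, which holds when each output depends only on its own input and the parameters:

$$\begin{aligned} I[(y_b)_B;\omega] &= H[(y_b)_B] - \mathbb{E}_{p(\omega)} H[(y_b)_B \mid \omega] \\ &= H[(y_b)_B] - \sum_{b \in B} \mathbb{E}_{p(\omega)} H[y_b \mid \omega] \end{aligned}$$

The conditional term is a sum of per-point entropies and is cheap to compute; the joint entropy $H[(y_b)_B]$ is the expensive part.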

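In the paper, we show that this function is submodular, which gives a $1 - 1/e$ approximation guarantee for a greedy algorithm. The candidate batch is therefore selected by greedy expansion: starting from an empty batch, we repeatedly add the pool point that most increases the acquisition function.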
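
Greedy expansion can be sketched in plain Python as follows. This is an illustrative toy, not the package's implementation: the names `entropy`, `mutual_information`, `greedy_batch` and `pool` are hypothetical, the input is assumed to be nested lists of probabilities indexed as `[pool point][parameter sample][class]`, and the joint entropy is computed by exact enumeration, which is only feasible for very small batches.

```python
import itertools
import math


def entropy(dist):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in dist if p > 0.0)


def mutual_information(probs_N_K_C, batch):
    """Estimate I[(y_b)_B; omega] from K parameter samples by exact enumeration."""
    num_samples = len(probs_N_K_C[0])
    num_classes = len(probs_N_K_C[0][0])
    # Joint predictive p(y_B) = mean over samples of the product of per-point probabilities.
    joint = []
    for labels in itertools.product(range(num_classes), repeat=len(batch)):
        total = 0.0
        for k in range(num_samples):
            prob = 1.0
            for n, c in zip(batch, labels):
                prob *= probs_N_K_C[n][k][c]
            total += prob
        joint.append(total / num_samples)
    # Conditional entropy: sum over batch points of the sample-averaged entropy.
    conditional = sum(
        entropy(probs_N_K_C[n][k]) for n in batch for k in range(num_samples)
    ) / num_samples
    return entropy(joint) - conditional


def greedy_batch(probs_N_K_C, batch_size):
    """Grow the batch one point at a time, always taking the best joint score."""
    batch = []
    for _ in range(batch_size):
        best_index, best_score = None, -math.inf
        for n in range(len(probs_N_K_C)):
            if n in batch:
                continue
            score = mutual_information(probs_N_K_C, batch + [n])
            if score > best_score:  # strict: ties keep the earlier index
                best_index, best_score = n, score
        batch.append(best_index)
    return batch


# Toy pool: points 0 and 1 are exact duplicates, point 2 is informative about
# a different aspect of the parameters but noisier on its own.
pool = [
    [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
    [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
    [[0.9, 0.1], [0.1, 0.9], [0.9, 0.1], [0.1, 0.9]],
]
print(greedy_batch(pool, 2))  # [0, 2]
```

Ranking the points by their individual scores would pick the two duplicates, since each has the highest score on its own. The joint score gives the duplicate no additional credit, so the greedy step picks point 2 instead.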

Joint entropies are hard to estimate and, for everything to work, one also has to use consistent MC dropout, which keeps a set of dropout masks fixed while scoring the pool set.
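
The idea behind consistent MC dropout can be illustrated without any deep learning framework. The sketch below is a hypothetical helper (`ConsistentDropoutMasks` is not part of this package, whose own implementation lives in `batchbald_redux.consistent_mc_dropout`): it samples one mask per MC sample on first use and reuses the same masks for every input until `reset()` is called.

```python
import random


class ConsistentDropoutMasks:
    """Hypothetical helper: K dropout masks that stay fixed until reset."""

    def __init__(self, num_features, num_samples, p=0.5, seed=None):
        self.num_features = num_features
        self.num_samples = num_samples
        self.p = p
        self._rng = random.Random(seed)
        self.masks = None

    def reset(self):
        """Discard the masks so that the next call samples fresh ones."""
        self.masks = None

    def _sample_masks(self):
        # Inverted dropout: kept units are scaled by 1 / (1 - p).
        keep_scale = 1.0 / (1.0 - self.p)
        return [
            [
                keep_scale if self._rng.random() >= self.p else 0.0
                for _ in range(self.num_features)
            ]
            for _ in range(self.num_samples)
        ]

    def __call__(self, features):
        """Return one masked copy of `features` per MC sample."""
        if self.masks is None:
            self.masks = self._sample_masks()
        return [[m * x for m, x in zip(mask, features)] for mask in self.masks]


layer = ConsistentDropoutMasks(num_features=4, num_samples=3, p=0.5, seed=0)
first = layer([1.0, 2.0, 3.0, 4.0])
second = layer([1.0, 2.0, 3.0, 4.0])
print(first == second)  # True: the same masks are applied to every input
```

Because sample `k` uses the same mask for every pool point, the `k`-th prediction for each point comes from the same sampled model, which is what the joint entropy estimate relies on.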

To aid reproducibility and baseline reproduction, we provide this simpler and clearer reimplementation.

How to use

We provide a simple example experiment that uses this package here.

To get a candidate batch using BatchBALD, we provide a simple API in batchbald_redux.batchbald:

from nbdev.showdoc import *
from batchbald_redux.batchbald import get_batchbald_batch

show_doc(get_batchbald_batch)

get_batchbald_batch

get_batchbald_batch(logits_N_K_C:Tensor, batch_size:int, num_samples:int, dtype=None, device=None)
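
The `_N_K_C` suffix names the tensor's dimensions: `N` pool points, `K` MC dropout samples, and `C` classes. The framework-free sketch below shows that layout with nested lists. It assumes (this is not stated above) that the function expects log-probabilities, that is, log-softmax outputs, rather than unnormalised scores; `log_softmax` and `to_log_probs_N_K_C` are hypothetical helpers, not part of the package.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one vector of class scores."""
    m = max(scores)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_norm for s in scores]


def to_log_probs_N_K_C(scores_N_K_C):
    """Normalise raw scores indexed as [pool point][MC sample][class]."""
    return [[log_softmax(s) for s in samples_K_C] for samples_K_C in scores_N_K_C]


# N = 2 pool points, K = 3 MC samples, C = 2 classes.
scores_N_K_C = [
    [[2.0, 0.0], [1.5, 0.5], [0.0, 2.0]],
    [[0.1, 0.0], [0.0, 0.1], [0.2, 0.1]],
]
log_probs_N_K_C = to_log_probs_N_K_C(scores_N_K_C)
```

With PyTorch, the equivalent would be a single tensor of shape `(N, K, C)`.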

We also provide a simple implementation of consistent MC dropout in batchbald_redux.consistent_mc_dropout.


