
A plugin to scikit-learn for quantum-classical hybrid solving.

D-Wave scikit-learn Plugin

This package provides a scikit-learn transformer for feature selection using a quantum-classical hybrid solver.

This plugin makes use of a Leap™ quantum-classical hybrid solver. Developers can get started by signing up for the Leap quantum cloud service for free. Those seeking a more collaborative approach and assistance with building a production application can reach out to D-Wave directly and also explore the feature selection offering in AWS Marketplace.

The package's main class, SelectFromQuadraticModel, can be used in any existing sklearn pipeline. For an introduction to hybrid methods for feature selection, see the Feature Selection for CQM example.
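To make the idea concrete, the sketch below scores feature subsets with a quadratic model of the general relevance-versus-redundancy kind: a linear term rewards each selected feature's correlation with the target, a quadratic term penalizes correlation between pairs of selected features, and only subsets of exactly k features are allowed. It is a pure-Python illustration under stated assumptions, not the plugin's implementation: the absolute Pearson correlations, the alpha weighting, and the helper names (pearson, qm_energy, select_features) are all hypothetical, and the exhaustive search stands in for the hybrid solver, which is what makes realistic problem sizes tractable.

```python
from itertools import combinations
from math import sqrt


def pearson(a, b):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0 or var_b == 0:
        return 0.0  # constant column: treat it as uncorrelated
    return cov / sqrt(var_a * var_b)


def qm_energy(selected, relevance, redundancy, alpha):
    """Energy of a subset (lower is better).

    Hypothetical weighting: (1 - alpha) scales the relevance reward,
    alpha scales the pairwise redundancy penalty.
    """
    linear = -(1 - alpha) * sum(relevance[i] for i in selected)
    quadratic = alpha * sum(redundancy[i][j] for i, j in combinations(selected, 2))
    return linear + quadratic


def select_features(columns, target, k, alpha=0.5):
    """Minimize the energy over every subset of exactly k features.

    Enumerating only k-subsets plays the role of the constraint
    sum(x_i) == k. Exhaustive search is only feasible for a handful of
    features, since the number of subsets grows combinatorially.
    """
    n = len(columns)
    relevance = [abs(pearson(col, target)) for col in columns]
    redundancy = [[abs(pearson(columns[i], columns[j])) for j in range(n)]
                  for i in range(n)]
    return min(combinations(range(n), k),
               key=lambda s: qm_energy(s, relevance, redundancy, alpha))


# Toy data: columns 0 and 1 are identical copies; column 2 is independent of them.
a = [1, -1, 1, -1]
b = [1, 1, -1, -1]
target = [2 * x + y for x, y in zip(a, b)]  # depends mostly on a, partly on b
columns = [a, list(a), b]

print(select_features(columns, target, k=2, alpha=0.0))  # (0, 1)
print(select_features(columns, target, k=2, alpha=0.5))  # (0, 2)
```

Ranking by relevance alone (alpha=0.0) keeps both copies of the stronger signal, while the redundancy penalty (alpha=0.5) swaps one copy for the weaker but independent column 2.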

Examples

Basic Usage

A minimal example of using the plugin to select 20 of 30 features of an sklearn dataset:

>>> from sklearn.datasets import load_breast_cancer
>>> from dwave.plugins.sklearn import SelectFromQuadraticModel
... 
>>> X, y = load_breast_cancer(return_X_y=True)
>>> X.shape
(569, 30)
>>> # solver can also be equal to "cqm"
>>> X_new = SelectFromQuadraticModel(num_features=20, solver="nl").fit_transform(X, y)
>>> X_new.shape
(569, 20)

For large problems, the default runtime may be insufficient. Ocean's hybrid samplers can estimate the minimum accepted runtime for a given problem (for example, the min_time_limit method of LeapHybridCQMSampler for the CQM solver; the Nonlinear (NL) solver provides a comparable estimate). Alternatively, simply submit as above and check the returned error message for the required runtime.

The feature selector can be re-instantiated with a longer time limit.

>>> # solver can also be equal to "nl"
>>> X_new = SelectFromQuadraticModel(num_features=20, time_limit=200, solver="cqm").fit_transform(X, y)

Tuning

You can use SelectFromQuadraticModel with scikit-learn's hyperparameter optimizers.

For example, the number of features can be tuned using a grid search. Note that the search fits the pipeline many times, and each fit submits a problem to the hybrid solver.

>>> import numpy as np
...
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.ensemble import RandomForestClassifier
>>> from sklearn.model_selection import GridSearchCV
>>> from sklearn.pipeline import Pipeline
>>> from dwave.plugins.sklearn import SelectFromQuadraticModel
...
>>> X, y = load_breast_cancer(return_X_y=True)
...
>>> num_features = X.shape[1]
>>> searchspace = np.linspace(1, num_features, num=5, dtype=int, endpoint=True)
...
>>> # solver can also be equal to "cqm"
>>> pipe = Pipeline([
...     ('feature_selection', SelectFromQuadraticModel(solver="nl")),
...     ('classification', RandomForestClassifier())
... ])
...
>>> clf = GridSearchCV(pipe, param_grid=dict(feature_selection__num_features=searchspace))
>>> search = clf.fit(X, y)
>>> print(search.best_params_)  # the selected value can vary between runs
{'feature_selection__num_features': 22}
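The number of problems such a search submits can be estimated with simple arithmetic. This sketch assumes GridSearchCV's defaults (5-fold cross-validation and refit=True) and assumes that each fit of SelectFromQuadraticModel submits exactly one problem; the helper name count_hybrid_submissions is hypothetical. Each candidate value is fitted once per fold, and the best candidate is fitted once more on the full dataset.

```python
def count_hybrid_submissions(n_candidates, n_folds=5, refit=True):
    """Number of pipeline fits GridSearchCV performs for a one-parameter grid.

    Assumes one hybrid-solver submission per fit of the feature selector.
    """
    # One fit per candidate per cross-validation fold...
    fits = n_candidates * n_folds
    # ...plus one final fit of the best candidate on all the data when refit=True.
    if refit:
        fits += 1
    return fits


# Values produced by np.linspace(1, 30, num=5, dtype=int) in the example above.
searchspace = [1, 8, 15, 22, 30]
print(count_hybrid_submissions(len(searchspace)))  # 26
```

Reducing the number of grid points, or passing a smaller cv value to GridSearchCV, reduces the number of submissions proportionally.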

Installation

To install the core package:

pip install dwave-scikit-learn-plugin

License

Released under the Apache License 2.0.

Contributing

Ocean's contributing guide has guidelines for contributing to Ocean packages.

Release Notes

dwave-scikit-learn-plugin makes use of reno to manage its release notes.

When making a contribution to dwave-scikit-learn-plugin that will affect users, create a new release note file by running:

reno new your-short-descriptor-here

You can then edit the file created under releasenotes/notes/. Remove any sections not relevant to your changes. Commit the file along with your changes.

Download files

Source Distribution

dwave_scikit_learn_plugin-0.2.0.tar.gz (20.6 kB)

Uploaded Source

Built Distribution

dwave_scikit_learn_plugin-0.2.0-py3-none-any.whl (16.6 kB)

Uploaded Python 3

File details

Details for the file dwave_scikit_learn_plugin-0.2.0.tar.gz.

File hashes

Hashes for dwave_scikit_learn_plugin-0.2.0.tar.gz
Algorithm Hash digest
SHA256 1976acb97c740df46e50d2452ed1a04ded892c8bec77ad18901b69ec6117d869
MD5 0e0ec5c0a5f77f0c6efffb441c57598f
BLAKE2b-256 04084eb091c429b4ecdf92b7b435df209938a70c9aecda1d86e0bc216422d3b6

File details

Details for the file dwave_scikit_learn_plugin-0.2.0-py3-none-any.whl.

File hashes

Hashes for dwave_scikit_learn_plugin-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 5b63e2c0aa8e1940c9ce56bbaf259dd2b1fa46b8b82fd5448175a0942ade9832
MD5 158beb842e6237511c9ed2347c33749e
BLAKE2b-256 910f7518d9973ebcf1c50fe56ba2ece1b1937566a39f2ee9ff5b1a0b99b4e08f
