Library to explain a dataset in natural language.
| Model | Reference | Description |
| --- | --- | --- |
| iPrompt | 📖, 🗂️, 🔗, 📄 | Generates a human-interpretable prompt that explains patterns in data (official) |
| Emb-GAM | 📖, 🗂️, 🔗, 📄 | Fits a better linear model by using an LLM to extract embeddings (official) |
| D3 | 📖, 🗂️, 🔗, 📄 | Explains the difference between two distributions |
| AutoPrompt | 🗂️, 🔗, 📄 | Finds a natural-language prompt using input gradients (⌛ in progress) |
| (Coming soon!) | ⌛ | We hope to support other interpretable models such as RLPrompt, concept bottleneck models, NAMs, and NBDTs |

Legend: 📖 demo notebook, 🗂️ docs, 🔗 reference code implementation, 📄 research paper
## Quickstart

**Installation:** `pip install imodelsx` (or, for more control, clone the repo and install from source).

**Demos:** see the demo notebooks.
### iPrompt

```python
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B',  # which language model to use
    num_learned_tokens=3,              # how long of a prompt to learn
    n_shots=3,                         # shots per prompt
    n_epochs=15,                       # how many epochs to search
    verbose=0,                         # how much to print
    llm_float16=True,                  # whether to load the model in float16
)
```
Here, `prompts` is a list of the natural-language prompt strings that were found.
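For intuition about the toy task above, the dataset returned by `get_add_two_numbers_dataset` can be approximated by a small generator like the one below. This is a hypothetical sketch; the exact string formatting used by imodelsx may differ:

```python
import random

def make_add_two_numbers(num_examples=5, seed=0):
    """Toy stand-in for get_add_two_numbers_dataset (formatting is a guess)."""
    rng = random.Random(seed)
    inputs, outputs = [], []
    for _ in range(num_examples):
        a, b = rng.randint(0, 99), rng.randint(0, 99)
        inputs.append(f'{a} {b}')   # the two numbers, as a string
        outputs.append(str(a + b))  # their sum, as a string
    return inputs, outputs

ins, outs = make_add_two_numbers()
```

The point is simply that the input/output pairs encode a pattern ("add the two numbers") that iPrompt tries to recover as a natural-language prompt.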
### D3 (DescribeDistributionalDifferences)

```python
import imodelsx

hypotheses, hypothesis_scores = imodelsx.explain_datasets_d3(
    pos=positive_samples,  # List[str] of positive examples
    neg=negative_samples,  # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
```
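D3 returns candidate hypotheses scored by how much better they describe the positive samples than the negative ones. The toy function below (not the imodelsx implementation) illustrates that scoring idea with a simple string predicate standing in for a hypothesis:

```python
def hypothesis_gap(holds, pos, neg):
    """Score a hypothesis (a boolean predicate on samples) by the gap
    between how often it holds on pos vs. neg samples."""
    pos_rate = sum(holds(s) for s in pos) / len(pos)
    neg_rate = sum(holds(s) for s in neg) / len(neg)
    return pos_rate - neg_rate

pos = ['great movie', 'great fun', 'loved it']
neg = ['terrible plot', 'awful acting', 'boring']
# candidate hypothesis: "the text mentions the word 'great'"
gap = hypothesis_gap(lambda s: 'great' in s, pos, neg)  # 2/3 - 0/3
```

A hypothesis with a large gap separates the two distributions well; D3 searches for such hypotheses expressed in natural language.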
### Emb-GAM

```python
from imodelsx import EmbGAMClassifier
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = EmbGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,  # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients:', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
```
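The interpretability of Emb-GAM comes from its additive structure: a document is represented as the sum of its ngram embeddings, so a fitted linear model yields one coefficient per ngram (the dot product of that ngram's embedding with the learned weights). The sketch below shows those mechanics with deterministic pseudo-embeddings standing in for LLM embeddings, and plain least squares standing in for the library's fitting procedure; it is an illustration, not the imodelsx implementation:

```python
import zlib
import numpy as np

def embed(ngram, dim=16):
    # deterministic random vector standing in for an LLM embedding
    seed = zlib.crc32(ngram.encode())
    return np.random.default_rng(seed).standard_normal(dim)

def ngrams(text, n=2):
    toks = text.split()
    return toks + [' '.join(toks[i:i + n]) for i in range(len(toks) - n + 1)]

docs = ['good movie', 'bad movie', 'good fun', 'bad plot']
y = np.array([1.0, 0.0, 1.0, 0.0])

# document vector = sum of its ngram embeddings
X = np.vstack([sum(embed(g) for g in ngrams(d)) for d in docs])

# fit a linear model (plain least squares here, for brevity)
w, *_ = np.linalg.lstsq(X, y - y.mean(), rcond=None)

# per-ngram coefficient = embedding . learned weights
coefs = {g: float(embed(g) @ w) for d in docs for g in ngrams(d)}
```

Because each document's score decomposes into a sum of per-ngram terms, inspecting the largest and smallest coefficients (as `coefs_dict_` does above) directly explains the model's predictions.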
## Related work

- imodels package (JOSS 2021, github): an interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
- Adaptive wavelet distillation (NeurIPS 2021, pdf, github): distills a neural network into a concise wavelet model.
- Transformation importance (ICLR 2020 workshop, pdf, github): uses simple reparameterizations to calculate disentangled importances for transformations of the input (e.g., assigning importances to different frequencies).
- Hierarchical interpretations (ICLR 2019, pdf, github): extends CD to CNNs / arbitrary DNNs and aggregates explanations into a hierarchy.
- Interpretation regularization (ICML 2020, pdf, github): penalizes CD / ACD scores during training to make models generalize better.
- PDR interpretability framework (PNAS 2019, pdf): an overarching framework for guiding and framing interpretable machine learning.