Library to explain a dataset in natural language.

📖 demo notebooks

Model           Reference         Description
iPrompt         📖, 🗂️, 🔗, 📄    Generates a human-interpretable prompt that explains patterns in data (Official)
Emb-GAM         📖, 🗂️, 🔗, 📄    Fit a better linear model using an LLM to extract embeddings (Official)
D3              📖, 🗂️, 🔗, 📄    Explain the difference between two distributions
AutoPrompt      🗂️, 🔗, 📄        Find a natural-language prompt using input-gradients (⌛ In progress)
(Coming soon!)                    We hope to support other interpretable models like RLPrompt, concept bottleneck models, NAMs, and NBDT

Demo notebooks 📖, Doc 🗂️, Reference code implementation 🔗, Research paper 📄

Quickstart

Installation: pip install imodelsx (or, for more control, clone and install from source)

Demos: see the demo notebooks

iPrompt

from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B', # which language model to use
    num_learned_tokens=3, # length of the prompt to learn, in tokens
    n_shots=3, # shots per example
    n_epochs=15, # how many epochs to search
    verbose=0, # how much to print
    llm_float16=True, # whether to load the model in float16 precision
)
# prompts is a list of the natural-language prompt strings that were found
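
The call above takes two parallel lists of strings, one input and one output per example. As a minimal sketch of that shape, the hypothetical helper below builds an addition dataset by hand. The string template it uses is an assumption made for illustration and is not necessarily the format returned by get_add_two_numbers_dataset.

```python
import random
from typing import List, Tuple


def make_addition_dataset(num_examples: int, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Hypothetical stand-in: build parallel lists of input and output strings.

    The string template is an assumption, not the library's actual format.
    """
    rng = random.Random(seed)
    input_strings: List[str] = []
    output_strings: List[str] = []
    for _ in range(num_examples):
        a, b = rng.randint(0, 99), rng.randint(0, 99)
        # Each input describes the two numbers; each output is their sum.
        input_strings.append(f"Given the numbers {a} and {b}, the answer is")
        output_strings.append(f" {a + b}")
    return input_strings, output_strings


inputs, outputs = make_addition_dataset(num_examples=5)
for x, y in zip(inputs, outputs):
    print(repr(x), repr(y))
```

Any pair of equal-length string lists built this way could be passed as input_strings and output_strings.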

D3 (DescribeDistributionalDifferences)

import imodelsx
hypotheses, hypothesis_scores = imodelsx.explain_datasets_d3(
    pos=positive_samples, # List[str] of positive examples
    neg=negative_samples, # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)
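
The snippet above expects positive_samples and negative_samples to exist already as lists of strings. One way to prepare them from a labeled dataset is sketched below. The helper split_by_label and the toy reviews are hypothetical and invented for illustration; they are not part of imodelsx.

```python
from typing import List, Sequence, Tuple


def split_by_label(
    texts: Sequence[str], labels: Sequence[int], positive_label: int = 1
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: split texts into (positive, negative) lists by label."""
    if len(texts) != len(labels):
        raise ValueError("texts and labels must have the same length")
    pos = [t for t, y in zip(texts, labels) if y == positive_label]
    neg = [t for t, y in zip(texts, labels) if y != positive_label]
    return pos, neg


# Toy data, invented for illustration only
texts = [
    "a moving, well-acted drama",
    "dull and overlong",
    "a delight from start to finish",
    "a tedious mess",
]
labels = [1, 0, 1, 0]

positive_samples, negative_samples = split_by_label(texts, labels)
print(len(positive_samples), len(negative_samples))
```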

Emb-GAM

from imodelsx import EmbGAMClassifier
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = EmbGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
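
The per-ngram coefficients printed above can be read as additive contributions to the prediction. The sketch below illustrates why, under the assumption that the model sums one embedding per ngram and then applies a linear layer. Everything in it is a stand-in: toy_embed replaces the LLM with a hash-based vector, the weights are made up, and extract_ngrams (unigrams plus bigrams) is a hypothetical choice that may differ from how the library interprets ngrams=2.

```python
import hashlib
from typing import Dict, List, Sequence

DIM = 8  # toy embedding size; a real LLM embedding is much larger


def toy_embed(ngram: str) -> List[float]:
    """Stand-in for an LLM embedding: a deterministic vector derived from a hash."""
    digest = hashlib.sha256(ngram.encode("utf-8")).digest()
    return [(b - 127.5) / 127.5 for b in digest[:DIM]]


def extract_ngrams(text: str, max_n: int = 2) -> List[str]:
    """Return all word ngrams of length 1..max_n (an assumed convention)."""
    words = text.lower().split()
    return [
        " ".join(words[i:i + n])
        for n in range(1, max_n + 1)
        for i in range(len(words) - n + 1)
    ]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


weights = [0.5, -1.0, 0.25, 0.0, 1.5, -0.75, 0.1, 0.9]  # made-up linear weights
text = "a charming and funny film"
ngrams = extract_ngrams(text)

# Route 1: sum the ngram embeddings, then apply the linear model once.
summed = [sum(col) for col in zip(*(toy_embed(g) for g in ngrams))]
logit_from_sum = dot(weights, summed)

# Route 2: compute one scalar coefficient per ngram, then add them up.
coefs: Dict[str, float] = {g: dot(weights, toy_embed(g)) for g in ngrams}
logit_from_coefs = sum(coefs[g] for g in ngrams)

# Both routes agree up to floating-point error, because the model is linear.
assert abs(logit_from_sum - logit_from_coefs) < 1e-9

for g, c in sorted(coefs.items(), key=lambda kv: kv[1], reverse=True)[:3]:
    print(g, round(c, 2))
```

Because the linear layer distributes over the sum, each ngram gets a single scalar score that can be sorted and inspected, which is the same kind of listing the interpretation loop above prints from coefs_dict_.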

Related work

  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
  • Transformation importance (ICLR 2020 workshop pdf, github) - uses simple reparameterizations to calculate disentangled importances for transformations of the input (e.g. assigning importances to different frequencies)
  • Hierarchical interpretations (ICLR 2019 pdf, github) - extends contextual decomposition (CD) to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning
