
Library to explain *a dataset* in natural language.

imodelsX: interpretability for teXt

Interpretable linear model that leverages a pre-trained language model to better learn interactions. One-line fit function.

📚 sklearn-friendly api • 📖 demo notebook

Official code for using / reproducing Emb-GAM from the paper "Emb-GAM: an interpretable and efficient predictor using pre-trained language models" (Singh & Gao, 2022). Emb-GAM uses a pre-trained language model to extract features from text data, then combines them to fit a simple linear model.

Quickstart

Installation: pip install imodelsx (or, for more control, clone and install from source)

Usage example (see api or demo notebook for more details):

from embgam import EmbGAMClassifier
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = EmbGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(np.array(preds) == np.array(dset_val['label'])))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
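
Because the fitted model is additive over ngrams, a single prediction can be read off by looking up the coefficient of each ngram in the text. The following is a minimal sketch, assuming a dictionary that maps ngram strings to floats (as `coefs_dict_` is used in the loops above). The helper names `extract_ngrams` and `explain`, the whitespace tokenization, the choice to include all ngrams up to length 2, and the toy coefficient values are hypothetical and not part of the package API; how the package itself tokenizes text or handles an intercept is not confirmed here.

```python
def extract_ngrams(text, max_n=2):
    """Return all word-level ngrams of length 1..max_n (whitespace tokenization)."""
    tokens = text.lower().split()
    return [
        " ".join(tokens[i:i + n])
        for n in range(1, max_n + 1)
        for i in range(len(tokens) - n + 1)
    ]


def explain(text, coefs, max_n=2):
    """Return (ngram, coefficient) pairs for ngrams that have a coefficient,
    sorted by absolute size, together with the sum of those coefficients."""
    contribs = [(g, coefs[g]) for g in extract_ngrams(text, max_n) if g in coefs]
    contribs.sort(key=lambda item: abs(item[1]), reverse=True)
    return contribs, sum(v for _, v in contribs)


# Toy coefficients for illustration only (not real model output).
toy_coefs = {"great": 1.5, "not great": -2.0, "not": -0.5, "movie": 0.1}

contribs, total = explain("Not great movie", toy_coefs)
for ngram, coef in contribs:
    print(f"{coef:+.2f}\t{ngram}")
print(f"{total:+.2f}\ttotal")
```

With a fitted model, `toy_coefs` would be replaced by `m.coefs_dict_`; the bigram "not great" outweighing the unigram "great" is the kind of interaction the model is meant to capture.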

Docs

Abstract: Deep learning models have achieved impressive prediction performance but often sacrifice interpretability, a critical consideration in high-stakes domains such as healthcare or policymaking. In contrast, generalized additive models (GAMs) can maintain interpretability but often suffer from poor prediction performance due to their inability to effectively capture feature interactions. In this work, we aim to bridge this gap by using pre-trained large language models to extract embeddings for each input before learning a linear model in the embedding space. The final model (which we call Emb-GAM) is a transparent, linear function of its input features and feature interactions. Leveraging the language model allows Emb-GAM to learn far fewer linear coefficients, model larger interactions, and generalize well to novel inputs (e.g. unseen ngrams in text). Across a variety of natural-language-processing datasets, Emb-GAM achieves strong prediction performance without sacrificing interpretability.
  • the main api requires simply importing embgam.EmbGAMClassifier or embgam.EmbGAMRegressor
  • the experiments and scripts folders contain hyperparameters for running the sweeps reported in the paper
  • the notebooks folder contains notebooks for analyzing the outputs + making figures
  • stored outputs after running all experiments are available in this gdrive folder
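
The idea in the abstract (embed each ngram, sum the embeddings, fit a linear model on the sum) can be sketched in a few lines. This is an illustrative sketch under stated assumptions, not the package's implementation: `toy_embed` is a hash-based stand-in for a language-model embedding, and `extract_ngrams`, `featurize`, `fit_logistic`, and `ngram_coef` are hypothetical helper names. The point it shows is that a model linear in a sum of embeddings collapses to one scalar coefficient per ngram, which can also be computed for ngrams never seen during training.

```python
import hashlib
import math

DIM = 8  # toy embedding size; a real language model uses far more dimensions


def toy_embed(ngram):
    """Stand-in for a language-model embedding: a fixed pseudo-random vector
    derived from a hash of the ngram. NOT a real language model."""
    digest = hashlib.sha256(ngram.encode("utf-8")).digest()
    return [byte / 255.0 - 0.5 for byte in digest[:DIM]]


def extract_ngrams(text, max_n=2):
    """All word-level ngrams of length 1..max_n (whitespace tokenization)."""
    tokens = text.lower().split()
    return [" ".join(tokens[i:i + n])
            for n in range(1, max_n + 1)
            for i in range(len(tokens) - n + 1)]


def featurize(text):
    """Represent a text as the sum of its ngram embeddings."""
    vec = [0.0] * DIM
    for ngram in extract_ngrams(text):
        for j, x in enumerate(toy_embed(ngram)):
            vec[j] += x
    return vec


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def fit_logistic(X, y, lr=0.1, epochs=200):
    """Stochastic gradient descent on the logistic loss; returns (weights, bias)."""
    w, b = [0.0] * DIM, 0.0
    for _ in range(epochs):
        for x, label in zip(X, y):
            p = 1.0 / (1.0 + math.exp(-(dot(w, x) + b)))
            for j in range(DIM):
                w[j] -= lr * (p - label) * x[j]
            b -= lr * (p - label)
    return w, b


# Tiny made-up training set, for illustration only.
texts = ["a great film", "truly great acting", "a dull film", "truly dull acting"]
labels = [1, 1, 0, 0]
w, b = fit_logistic([featurize(t) for t in texts], labels)


def ngram_coef(ngram):
    """Scalar coefficient of one ngram: w . embed(ngram)."""
    return dot(w, toy_embed(ngram))


# The logit computed from summed embeddings equals the sum of per-ngram
# coefficients plus the bias, even though "great dull" was never seen in training.
text = "a great dull film"
logit_from_features = dot(w, featurize(text)) + b
logit_from_coefs = sum(ngram_coef(g) for g in extract_ngrams(text)) + b
print(logit_from_features, logit_from_coefs)
```

The two printed values agree up to floating-point rounding. With a hash-based embedding the coefficients of unseen ngrams carry no meaning; the generalization described in the abstract depends on the embeddings coming from a pre-trained language model.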

Related work

  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
  • Transformation importance (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • Hierarchical interpretations (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

If this package is useful for you, please cite the following!

@article{singh2022embgam,
  title = {Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author = {Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  doi = {10.48550/arxiv.2209.11799},
  url = {https://arxiv.org/abs/2209.11799},
  year = {2022},
}
