
plamtral

PLaMTraL - A transfer learning library for pre-trained transformers.

Installation

Install plamtral with pip:

pip install plamtral

Features

Fine-tuning

Fine-tuning large pre-trained language models on downstream tasks remains the de-facto learning paradigm in NLP. However, several fine-tuning approaches exist besides the usual vanilla variant, and they can be more effective or more efficient. The fine-tuning techniques provided in this package are:

  • BitFit - a sparse fine-tuning method in which only the bias terms of the model (or a subset of them) are modified. Reference: https://arxiv.org/pdf/2106.10199.pdf. (A minimal sketch of this idea follows the list.)
  • Chain-thaw - an approach that sequentially unfreezes and fine-tunes a single layer at a time. Reference: https://arxiv.org/pdf/1708.00524.pdf.
  • ULMFiT - an effective transfer learning method that introduces techniques (slanted triangular learning rates, discriminative fine-tuning, and gradual unfreezing) that are key for fine-tuning a language model. Reference: https://arxiv.org/pdf/1801.06146.pdf.
  • Vanilla fine-tuning - the standard fine-tuning approach (fine-tune the whole model, fine-tune the last n layers, or fine-tune a specific layer).
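To make the bias-only idea behind BitFit concrete, here is a minimal sketch that uses the Hugging Face transformers API directly. It is illustrative only and does not go through plamtral's own classes (the library's API is shown in the usage example further down).

import torch
from transformers import GPT2LMHeadModel

# Load a pre-trained GPT2 language model
model = GPT2LMHeadModel.from_pretrained('gpt2')

# BitFit-style freezing: keep only the bias terms trainable
for name, param in model.named_parameters():
    param.requires_grad = 'bias' in name

# Pass only the trainable (bias) parameters to the optimizer
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {trainable}')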

Parameter efficient approaches

Since conventional fine-tuning can become expensive, as it often requires storing a full set of model parameters per task, recent work has proposed a variety of parameter-efficient transfer learning methods that fine-tune only a small number of (extra) parameters while attaining strong performance. The parameter-efficient techniques provided in this package include adapter-based methods such as parallel adapters, which are used in the usage example below.
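For intuition, the sketch below shows the parallel-adapter idea in plain PyTorch: a small bottleneck network runs alongside a (frozen) transformer sub-layer and its output is added to that sub-layer's output. The class name, bottleneck size, and forward signature are illustrative assumptions, not plamtral's actual implementation; see the usage example below for the library's API.

import torch.nn as nn

class ParallelAdapter(nn.Module):
    """Bottleneck adapter applied in parallel to a frozen transformer sub-layer."""
    def __init__(self, hidden_size, bottleneck_size=64):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck_size)
        self.up = nn.Linear(bottleneck_size, hidden_size)
        self.act = nn.ReLU()

    def forward(self, hidden_states, sublayer_output):
        # The adapter reads the sub-layer's input; its small correction is
        # added to the output of the frozen sub-layer.
        return sublayer_output + self.up(self.act(self.down(hidden_states)))

During fine-tuning, only the adapter parameters are updated while the pre-trained weights stay frozen, which keeps the number of task-specific parameters small.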

Usage/Examples

To use a GPT2 model with parallel adapters (for example):

from parameter_efficient.adapter import Model_with_parallel_adapter
from tl_lib.utils import load_dataloaders
from tl_lib.tl_train import train

# Load the GPT2 model with Parallel Adapters
model_obj = Model_with_parallel_adapter('GPT2')
# Create the train, validation and test dataloaders from the dataset file
train_loader, val_loader, test_loader = load_dataloaders('GPT2', dataset_path='path/to/dataset_file')
# Train the model
train(model_obj, train_loader, val_loader, verbose=True, model_save_name='path/to/model')

Requirements

  • torch 1.12.1
  • tqdm 4.64.1
  • transformers 4.24.0
  • nltk 3.7
  • torchmetrics

Authors

@Vibhu04

License

MIT
