
Python package to easily retrain OpenAI's GPT-2 text-generating model on new texts.



A simple Python package that wraps existing model fine-tuning and generation scripts for OpenAI's GPT-2 text generation model (specifically the "small", 117M parameter version). The package also makes text generation easier: it can generate to a file for easy curation and accepts a prefix to force the text to start with a given phrase.

## Usage

An example of downloading the model to the local system, finetuning it on a dataset, and generating some text:

Warning: the pretrained model, and thus any finetuned model, is 500 MB!

```python
import gpt_2_simple as gpt2

gpt2.download_gpt2() # model is saved into current directory under /models/117M/

sess = gpt2.start_tf_sess()
gpt2.finetune(sess, 'shakespeare.txt', steps=1000) # steps is max number of training steps

gpt2.generate(sess)
```

The generated model checkpoints are by default in `/checkpoint/run1`. If you want to load a model from that folder and generate text from it:

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)

gpt2.generate(sess)
```

As with textgenrnn, you can generate and save text for later use (e.g. an API or a bot) by using the `return_as_list` parameter.

```python
single_text = gpt2.generate(sess, return_as_list=True)[0]
print(single_text)
```
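Because `return_as_list=True` yields ordinary Python strings, saving them for later curation needs nothing beyond the standard library. The sketch below is a minimal illustration and not part of gpt_2_simple: `save_texts`, the delimiter, and the output filename are hypothetical, and the `samples` list stands in for what `gpt2.generate(sess, return_as_list=True)` would return.

```python
import tempfile
from pathlib import Path


def save_texts(texts, destination, delimiter="=" * 20):
    """Write each text to `destination`, separated by a delimiter line."""
    path = Path(destination)
    body = f"\n{delimiter}\n".join(text.strip() for text in texts)
    path.write_text(body + "\n", encoding="utf-8")
    return path


# Placeholder strings; in practice this would be the list returned by
# gpt2.generate(sess, return_as_list=True).
samples = ["First generated text.", "Second generated text."]

# Hypothetical output location in the system temp directory.
saved = save_texts(samples, Path(tempfile.gettempdir()) / "generated_samples.txt")
```

The delimiter line makes it easy to split the file back into individual texts when reviewing them.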

You can pass a `run_name` parameter to `finetune` and `load_gpt2` if you want to store/load multiple models in a `checkpoint` folder.
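To see which runs are available before calling `load_gpt2`, the `checkpoint` folder can be inspected directly. This is a sketch under the assumption that each run is stored as a subdirectory of `checkpoint` named after its `run_name` (as the default `/checkpoint/run1` suggests); `list_runs` is a hypothetical helper, not a gpt_2_simple function.

```python
from pathlib import Path


def list_runs(checkpoint_dir="checkpoint"):
    """Return the sorted names of run subdirectories under `checkpoint_dir`."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return []  # nothing has been finetuned here yet
    return sorted(entry.name for entry in root.iterdir() if entry.is_dir())
```

A name returned by this helper could then be passed as the `run_name` argument to `load_gpt2`.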

NB: *Restart the Python session first* if you want to finetune on another dataset or load another model.
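When several finetuning jobs need to run from one script, one way to honor the restart requirement is to launch each job in its own interpreter process. The following is a sketch of that idea, not something the package provides: `run_in_fresh_interpreter` is a hypothetical helper, and the job shown in the comment is illustrative.

```python
import subprocess
import sys


def run_in_fresh_interpreter(source):
    """Run Python `source` in a new interpreter process and return its stdout."""
    result = subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the job fails
    )
    return result.stdout


# A real job would contain the finetuning calls, for example:
#   import gpt_2_simple as gpt2
#   sess = gpt2.start_tf_sess()
#   gpt2.finetune(sess, 'shakespeare.txt', run_name='shakespeare', steps=1000)
# A trivial job keeps this sketch runnable without the model installed.
output = run_in_fresh_interpreter("print('job finished')")
```

Each call starts with a clean process, so no state from a previous finetuning job carries over.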


