
TensorFlow-accessible helper for Baidu's ERNIE NLP model


ERNIE for the Rest of Us

The ERNIE model is a pre-trained natural language processing model from Baidu. ERNIE 2.0 claimed the top GLUE score as of December 11, 2019. However, ERNIE is implemented in Baidu's Paddle deep learning framework, whereas the rest of us typically use more popular frameworks such as TensorFlow. This project provides an accessible package that recreates the ERNIE model in TensorFlow, initialized with the original ERNIE-trained weights.

Installation and Usage

First, install the ernie4us Python package via pip:

pip install ernie4us

Check out the ERNIE4us_demo.ipynb Jupyter notebook to see how to load and use the converted model with the TensorFlow API.

Extracting the ERNIE model yourself

First, download a pre-trained ERNIE 2.0 model in Paddle format. The available variations and download locations can be found on ERNIE's GitHub project.
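
If you prefer to unpack the downloaded archive from Python instead of with the tar command used in the session below, the following is a minimal sketch. It assumes the model_artifacts/<model name>/paddle layout shown in that session; the function name unpack_paddle_model is hypothetical and not part of ernie4us.

```python
import tarfile
from pathlib import Path


def unpack_paddle_model(archive: str, model_name: str, root: str = "model_artifacts") -> Path:
    """Hypothetical helper: unpack an ERNIE .tar.gz into <root>/<model_name>/paddle.

    Roughly equivalent to `tar -C <root>/<model_name>/paddle -zxf <archive>`.
    """
    dest = Path(root) / model_name / "paddle"
    dest.mkdir(parents=True, exist_ok=True)  # tar -C requires the directory to exist
    with tarfile.open(archive, mode="r:gz") as tar:
        # Use the safer "data" extraction filter where this Python version supports it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest, filter="data")
        else:
            tar.extractall(path=dest)
    return dest
```

For example, unpack_paddle_model("ERNIE_Large_en_stable-2.0.0.tar.gz", "ERNIE_Large_en_stable-2.0.0") would populate model_artifacts/ERNIE_Large_en_stable-2.0.0/paddle before running the extraction script.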

Run the extract_ernie_params.sh script to extract model parameters and copy artifacts:

$ cd ~/ERNIE-for-the-rest-of-us
$ tar -C model_artifacts/ERNIE_Large_en_stable-2.0.0/paddle -zxf ~/Downloads/ERNIE_Large_en_stable-2.0.0.tar.gz
$ ./extract_ernie_params.sh
Usage: ./extract_ernie_params.sh ERNIE_MODEL_NAME
  ERNIE_MODEL_NAME the model identifier. Allowed identifiers are
    ERNIE_Base_en_stable-2.0.0
    ERNIE_Large_en_stable-2.0.0
$ ./extract_ernie_params.sh ERNIE_Large_en_stable-2.0.0
# .... some outputs ...
totally 392 persistables
copying artifacts...
total 2621456
-rw-r--r--@ 1 user_x  staff   330B Feb 20 16:36 ERNIE_Large_en_stable-2.0.0_config.json
-rw-r--r--@ 1 user_x  staff   226K Feb 20 16:36 ERNIE_Large_en_stable-2.0.0_vocab.txt
-rw-r--r--  1 user_x  staff   1.2G Feb 20 16:36 ERNIE_Large_en_stable-2.0.0_persistables.pkl
drwxr-xr-x  5 user_x  staff   160B Feb 20 16:36 paddle
ERNIE_Large_en_stable-2.0.0 parameters are exported to ./model_artifacts/ERNIE_Large_en_stable-2.0.0
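
The exported *_vocab.txt file can be inspected directly. The sketch below assumes each line is either a bare token (whose id is its line number) or a tab-separated "token<TAB>id" pair; these are the layouts commonly used by BERT- and ERNIE-style vocab files, but the exact format of the exported file is an assumption here. The function name load_vocab is hypothetical.

```python
from typing import Dict


def load_vocab(path: str) -> Dict[str, int]:
    """Hypothetical helper: map each vocabulary token to its integer id.

    Assumed line formats: "token" (id = line number) or "token<TAB>id".
    """
    vocab: Dict[str, int] = {}
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f):
            line = line.rstrip("\r\n")
            if not line:
                continue  # skip blank lines
            token, sep, idx = line.partition("\t")
            vocab[token] = int(idx) if sep else line_no
    return vocab
```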

Please note that the extraction script installs a specific version of TensorFlow, overriding any currently installed version. It is advised to set up a dedicated virtual environment for this work.
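
Because the script replaces the installed TensorFlow, a small guard can confirm you are inside a virtual environment before running it. This is a sketch using only the standard library; the function name in_virtualenv is hypothetical, and it only detects venv/virtualenv environments, not conda environments.

```python
import sys


def in_virtualenv() -> bool:
    """Hypothetical helper: True when running inside a venv/virtualenv."""
    # venv points sys.prefix at the environment while sys.base_prefix keeps the
    # base installation; legacy virtualenv releases set sys.real_prefix instead.
    return hasattr(sys, "real_prefix") or sys.prefix != sys.base_prefix
```

A wrapper script could call in_virtualenv() and abort with a message when it returns False, before invoking extract_ernie_params.sh.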

After that, one can run the ERNIE4us_verification.ipynb Jupyter notebook to verify that the ERNIE model recreated in TensorFlow produces the same hidden features as the original Paddle implementation.
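
"The same hidden features" in practice means equal within floating-point tolerance. The sketch below shows one way to express such a check; it is not necessarily what the notebook does. It assumes both outputs have already been flattened to 1-D lists of floats (for example with NumPy's .ravel().tolist()), and the tolerance defaults are assumptions, not values from this project.

```python
import math
from typing import Sequence


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two flattened outputs."""
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def features_match(a: Sequence[float], b: Sequence[float],
                   atol: float = 1e-5, rtol: float = 1e-4) -> bool:
    """True if both outputs have the same length and agree element-wise within tolerance."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rtol, abs_tol=atol) for x, y in zip(a, b)
    )
```

Reporting max_abs_diff alongside the boolean result makes it easier to judge whether a mismatch is float32 noise or a genuine conversion error.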

To use the extracted parameters in the model, put the artifact files in a subfolder named after the model, under the path you pass to the ernie4us.load_ernie_model method, e.g.:

import ernie4us
# Extracted ERNIE artifacts in /user/local/lib/ernie4us/ERNIE_Large_en_stable-2.0.0/
input_builder, ernie_tf_inputs, ernie_tf_outputs = ernie4us.load_ernie_model(
  model_name='ERNIE_Large_en_stable-2.0.0', 
  model_path='/user/local/lib/ernie4us')
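
Before calling load_ernie_model, it can help to check that the expected files are in place. The file suffixes below are taken from the extraction output shown above; that the loader needs exactly these three files is an assumption. The function name missing_artifacts is hypothetical.

```python
from pathlib import Path
from typing import List

# Suffixes copied from the extraction script's output listing (assumed to be what the loader reads).
ARTIFACT_SUFFIXES = ("_config.json", "_vocab.txt", "_persistables.pkl")


def missing_artifacts(model_path: str, model_name: str) -> List[str]:
    """Hypothetical helper: list expected artifact files absent from <model_path>/<model_name>/."""
    model_dir = Path(model_path) / model_name
    return [
        model_name + suffix
        for suffix in ARTIFACT_SUFFIXES
        if not (model_dir / (model_name + suffix)).is_file()
    ]
```

An empty list means all three files were found, e.g. missing_artifacts('/user/local/lib/ernie4us', 'ERNIE_Large_en_stable-2.0.0').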

TensorFlow version support

Version 0.1.15 of this library supports TensorFlow 1.15; version 2.x works only with TensorFlow 2.x.
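
The compatibility rule can be turned into a pip requirement string. This sketch assumes that the "2.x" line refers to the 0.2.x releases of ernie4us (the wheel listed on this page is 0.2.0); that mapping is an assumption, and the function name ernie4us_requirement is hypothetical.

```python
def ernie4us_requirement(tf_version: str) -> str:
    """Hypothetical helper: suggest an ernie4us pin for a TensorFlow version string."""
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    if (major, minor) == (1, 15):
        return "ernie4us==0.1.15"
    if major == 2:
        # Assumption: the TensorFlow 2.x line corresponds to ernie4us 0.2.x.
        return "ernie4us>=0.2,<0.3"
    raise ValueError(f"no ernie4us release documented for TensorFlow {tf_version}")
```

With TensorFlow installed, passing tf.__version__ to this function gives the pin to use.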

References and Credits

The modeling code is adapted from the original BERT implementation and modified to accept ERNIE parameters, as well as migrated to TensorFlow 2.0 / Keras.

License

Apache 2.0


Download files

Source Distributions

No source distribution files are available for this release.

Built Distribution

ernie4us-0.2.0-py3-none-any.whl (24.3 kB)

Uploaded Python 3

File details

Details for the file ernie4us-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: ernie4us-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 24.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.2.0 requests-toolbelt/0.9.1 tqdm/4.48.0 CPython/3.7.7

File hashes

Hashes for ernie4us-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 5518f70321a230e0106448a32fb8a4f1a1418c73b66312618c0aef7c7d9cd28f
MD5 bfd8592a3aa432cc345e77b59eacda1e
BLAKE2b-256 eae27cfb5f4ad2e818a405bbc5a93c1ac5111ec985a7e07e3a202a29d2222458
