
A TensorFlow 2.0 Keras implementation of BERT.


This repo contains a TensorFlow 2.0 Keras implementation of google-research/bert with support for loading the original pre-trained weights, and it produces activations numerically identical to those computed by the original model.

The implementation is built from scratch using only basic TensorFlow operations, following the code in google-research/bert/modeling.py (but skipping dead code and applying some simplifications). It also uses kpe/params-flow to reduce common Keras boilerplate code (related to passing model and layer configuration arguments).

bert-for-tf2 should work with both TensorFlow 2.0 and TensorFlow 1.14 or newer.

NEWS

  • 28.Jun.2019 - v.0.3.0 supports adapter-BERT (google-research/adapter-bert) for “Parameter-Efficient Transfer Learning for NLP”, i.e. fine-tuning small overlay adapter layers over BERT’s transformer encoders without changing the frozen BERT weights; see the sketch below.
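
A minimal sketch of what adapter-BERT fine-tuning looks like with this package, assuming the adapter_size field that v.0.3.0 adds to BertModelLayer.Params (None, the default, yields the plain BERT architecture):

from bert import BertModelLayer

# adapter_size sets the bottleneck width of the adapter layers;
# leaving it at None (the default) disables adapters entirely
l_bert = BertModelLayer.from_params(BertModelLayer.Params(
  vocab_size               = 16000,
  num_layers               = 12,
  hidden_size              = 768,
  intermediate_size        = 4*768,
  intermediate_activation  = "gelu",
  adapter_size             = 64,           # bottleneck size of the adapters
), name="bert")

With adapters enabled, only the adapter (and layer normalization) weights are meant to remain trainable; freezing the rest of the pre-trained BERT weights is left to the training code.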

LICENSE

MIT. See License File.

Install

bert-for-tf2 is on the Python Package Index (PyPI):

pip install bert-for-tf2

Usage

BERT in bert-for-tf2 is implemented as a Keras layer. You can instantiate it like this:

from bert import BertModelLayer

l_bert = BertModelLayer.from_params(BertModelLayer.Params(
  vocab_size               = 16000,        # embedding params
  use_token_type           = True,
  use_position_embeddings  = True,
  token_type_vocab_size    = 2,

  num_layers               = 12,           # transformer encoder params
  hidden_size              = 768,
  hidden_dropout           = 0.1,
  intermediate_size        = 4*768,
  intermediate_activation  = "gelu",
), name="bert")                            # name, like any other Keras layer argument

or by using the bert_config.json from a pre-trained Google model:

import os
import tensorflow as tf
from bert import BertModelLayer
from bert.loader import StockBertConfig, load_stock_weights

model_dir = ".models/uncased_L-12_H-768_A-12"

bert_config_file = os.path.join(model_dir, "bert_config.json")
bert_ckpt_file   = os.path.join(model_dir, "bert_model.ckpt")

with tf.io.gfile.GFile(bert_config_file, "r") as reader:
  stock_params = StockBertConfig.from_json_string(reader.read())
  bert_params  = stock_params.to_bert_model_layer_params()

l_bert = BertModelLayer.from_params(bert_params, name="bert")

now you can use the BERT layer in your Keras model like this:

from tensorflow.python import keras

max_seq_len = 128
l_input_ids      = keras.layers.Input(shape=(max_seq_len,), dtype='int32')
l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32')

output = l_bert([l_input_ids, l_token_type_ids])  # [batch_size, max_seq_len, hidden_size]

and build (or compile) your model:

model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output)
model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)])

before loading the original pre-trained weights into the BERT layer:

from bert.loader import load_stock_weights

load_stock_weights(l_bert, bert_ckpt_file)

N.B. see tests/test_bert_activations.py for a complete example.
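
To see how the layer slots into a downstream task, here is a sketch of a sentence classifier on top of the BERT layer built above; the first-token ([CLS]) pooling and the dense head are illustrative choices, not part of this package:

from tensorflow.python import keras

max_seq_len = 128
num_classes = 2                                  # illustrative

l_input_ids      = keras.layers.Input(shape=(max_seq_len,), dtype='int32')
l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32')

seq_output = l_bert([l_input_ids, l_token_type_ids])  # [batch_size, max_seq_len, hidden_size]

# pool the first ([CLS]) token as a fixed-size representation of the sequence
cls_output = keras.layers.Lambda(lambda seq: seq[:, 0, :])(seq_output)
logits     = keras.layers.Dense(num_classes)(cls_output)

model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=logits)
model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)])
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

As above, load_stock_weights(l_bert, bert_ckpt_file) should be called after model.build() and before training.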
