# transformers-keras

Transformer-based models implemented in TensorFlow 2.x (Keras).
## Installation

```bash
pip install -U transformers-keras
```
## Models

- Transformer (deleted)
  - Attention Is All You Need.
  - Here is a tutorial from TensorFlow: Transformer model for language understanding.
- BERT
- ALBERT
## BERT

Supported pretrained models:

- All the BERT models pretrained by google-research/bert
- All the BERT & RoBERTa models pretrained by ymcui/Chinese-BERT-wwm
Use the model to predict directly:

```python
from transformers_keras import Bert

model = Bert.from_pretrained('/path/to/pretrained/bert/model')
# segment_ids and mask inputs are optional
model.predict((input_ids, segment_ids, mask))
# or
model(inputs=(input_ids, segment_ids, mask))
```

Or use it for fine-tuning:

```python
import os

import tensorflow as tf

from transformers_keras import Bert


def build_bert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    bert = Bert.from_pretrained(pretrained_model_dir, **kwargs)
    bert.trainable = trainable
    sequence_outputs, pooled_output = bert(inputs=(input_ids, segment_ids))
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    return model


model = build_bert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'chinese_wwm_ext_L-12_H-768_A-12'),
    trainable=True)
model.summary()
```
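The snippets above assume you already have `input_ids`, `segment_ids`, and `mask` arrays. As a rough illustration of what those look like for a BERT-style model, here is a framework-free sketch; the tiny vocabulary and the `encode` helper are hypothetical, not part of transformers-keras (in practice, use the tokenizer and vocab.txt that ship with your pretrained model):

```python
# Hypothetical toy vocabulary; real models ship a vocab.txt with ~20k+ entries.
VOCAB = {'[PAD]': 0, '[CLS]': 101, '[SEP]': 102, 'hello': 7592, 'world': 2088}


def encode(tokens, max_len=8):
    """Wrap tokens with [CLS]/[SEP], then pad ids and mask to max_len."""
    ids = [VOCAB['[CLS]']] + [VOCAB[t] for t in tokens] + [VOCAB['[SEP]']]
    mask = [1] * len(ids) + [0] * (max_len - len(ids))  # 1 = real token, 0 = padding
    ids = ids + [VOCAB['[PAD]']] * (max_len - len(ids))
    segment_ids = [0] * max_len  # all zeros for a single-sentence input
    return ids, segment_ids, mask


input_ids, segment_ids, mask = encode(['hello', 'world'])
print(input_ids)  # [101, 7592, 2088, 102, 0, 0, 0, 0]
print(mask)       # [1, 1, 1, 1, 0, 0, 0, 0]
```

Batch these lists into arrays of shape `(batch_size, max_len)` before feeding them to the model.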
## ALBERT

Supported pretrained models:

- All the ALBERT models pretrained by google-research/albert
Use the model to predict directly:

```python
from transformers_keras import Albert

model = Albert.from_pretrained('/path/to/pretrained/albert/model')
# segment_ids and mask inputs are optional
model.predict((input_ids, segment_ids, mask))
# or
model(inputs=(input_ids, segment_ids, mask))
```

Or use it for fine-tuning:

```python
import os

import tensorflow as tf

from transformers_keras import Albert


def build_albert_classify_model(pretrained_model_dir, trainable=True, **kwargs):
    input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
    # segment_ids and mask inputs are optional
    segment_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='segment_ids')
    albert = Albert.from_pretrained(pretrained_model_dir, **kwargs)
    albert.trainable = trainable
    sequence_outputs, pooled_output = albert(inputs=(input_ids, segment_ids))
    outputs = tf.keras.layers.Dense(2, name='output')(pooled_output)
    model = tf.keras.Model(inputs=[input_ids, segment_ids], outputs=outputs)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    return model


model = build_albert_classify_model(
    pretrained_model_dir=os.path.join(BASE_DIR, 'albert_base'),
    trainable=True)
model.summary()
```
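The classifiers above take single sentences, where `segment_ids` can be all zeros. For sentence-pair tasks, the optional `segment_ids` input is what tells the model where one sentence ends and the other begins. A hedged, plain-Python sketch of the usual convention (the `pair_segment_ids` helper is hypothetical, not part of this library):

```python
def pair_segment_ids(len_a, len_b):
    """Segment ids for a [CLS] A [SEP] B [SEP] input: 0 for the first
    sentence and its special tokens, 1 for the second sentence and its [SEP]."""
    # +2 covers [CLS] and the first [SEP]; +1 covers the second [SEP]
    return [0] * (len_a + 2) + [1] * (len_b + 1)


print(pair_segment_ids(3, 2))  # [0, 0, 0, 0, 0, 1, 1, 1]
```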
## Load other pretrained models

If you want to load models pretrained by other implementations, whose configs and trainable weights differ slightly from the above, you can subclass `AbstractAdapter` to adapt these models:
```python
from transformers_keras.adapters import AbstractAdapter
from transformers_keras import Bert, Albert


# load custom bert models
class MyBertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.bert_adapter`
        pass


bert = Bert.from_pretrained('/path/to/your/bert/model', adapter=MyBertAdapter())


# or, load custom albert models
class MyAlbertAdapter(AbstractAdapter):

    def adapte_config(self, config_file, **kwargs):
        # adapt the model config here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass

    def adapte_weights(self, model, config, ckpt, **kwargs):
        # adapt the model weights here
        # you can refer to `transformers_keras.adapters.albert_adapter`
        pass


albert = Albert.from_pretrained('/path/to/your/albert/model', adapter=MyAlbertAdapter())
```
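To make the config half of an adapter concrete, here is a hedged, framework-free sketch of the kind of key renaming `adapte_config` typically performs: another implementation names its hyperparameters differently, and the adapter maps them onto the names this library expects. All key names below are hypothetical; refer to `transformers_keras.adapters.bert_adapter` for the real mapping:

```python
# Hypothetical mapping from a foreign config's key names to the target names.
# The actual keys used by transformers_keras may differ.
KEY_MAP = {
    'hidden_dim': 'hidden_size',
    'n_layers': 'num_layers',
    'n_heads': 'num_attention_heads',
}


def adapt_config_dict(foreign_config):
    """Rename known foreign keys to the target names; pass the rest through."""
    return {KEY_MAP.get(k, k): v for k, v in foreign_config.items()}


adapted = adapt_config_dict({'hidden_dim': 768, 'n_layers': 12, 'vocab_size': 21128})
print(adapted)  # {'hidden_size': 768, 'num_layers': 12, 'vocab_size': 21128}
```

`adapte_weights` does the analogous job for checkpoint variables, matching each variable name in the foreign checkpoint to the corresponding weight in the Keras model.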
## File details

Details for the file `transformers_keras-0.2.2.tar.gz`.

### File metadata

- Download URL: transformers_keras-0.2.2.tar.gz
- Upload date:
- Size: 17.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | dd1a5f975d86e0370b88b25132860a07e25569f62e4350d352903b4e8334d2d9 |
| MD5 | 19cc5a6690c8d892bf58e42656c6c0d3 |
| BLAKE2b-256 | b57e3f7fcddaef87bc90cd9155765e3c2ae8a5d7911a597001f09e622b99428a |
## File details

Details for the file `transformers_keras-0.2.2-py3-none-any.whl`.

### File metadata

- Download URL: transformers_keras-0.2.2-py3-none-any.whl
- Upload date:
- Size: 28.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.9.1

### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 535fa94d2306239780165bd245b3515c57cb8087588bbe6490a4b8e0210309e1 |
| MD5 | 906ca5cd588e6007d700836b2c9b8832 |
| BLAKE2b-256 | 5516a407820d4ae5bd7b555176feb1d3f3575a93227d6680b7fdacaacc31d077 |