# transformers-keras

Transformer-based models implemented in TensorFlow 2.x (Keras).
## Installation

```bash
pip install -U transformers-keras
```
## Models

- Transformer
  - Attention Is All You Need.
  - Here is a tutorial from TensorFlow: Transformer model for language understanding.
- BERT
- ALBERT
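All three models are built on attention. As background, here is a minimal NumPy sketch of the scaled dot-product attention from *Attention Is All You Need* — illustrative only, not part of this library's API:

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V.

    q: (len_q, d_k), k: (len_k, d_k), v: (len_k, d_v).
    """
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                 # (len_q, len_k)
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v, weights                     # output, attention weights

rng = np.random.default_rng(0)
out, attn = scaled_dot_product_attention(
    rng.standard_normal((4, 8)),    # queries
    rng.standard_normal((6, 8)),    # keys
    rng.standard_normal((6, 16)),   # values
)
```

Each row of `attn` is a probability distribution over the keys, so each output row is a weighted average of the value vectors.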
## Transformer

Train a new Transformer model:
```python
from transformers_keras import TransformerTextFileDatasetBuilder
from transformers_keras import TransformerDefaultTokenizer
from transformers_keras import TransformerRunner

# Separate tokenizers for the source and target languages
src_tokenizer = TransformerDefaultTokenizer(vocab_file='testdata/vocab_src.txt')
tgt_tokenizer = TransformerDefaultTokenizer(vocab_file='testdata/vocab_tgt.txt')
dataset_builder = TransformerTextFileDatasetBuilder(src_tokenizer, tgt_tokenizer)

model_config = {
    'num_encoder_layers': 2,
    'num_decoder_layers': 2,
    'src_vocab_size': src_tokenizer.vocab_size,
    'tgt_vocab_size': tgt_tokenizer.vocab_size,
}

runner = TransformerRunner(model_config, dataset_builder, model_dir='/tmp/transformer')

# Each training entry is a (source file, target file) pair
train_files = [('testdata/train.src.txt', 'testdata/train.tgt.txt')]
runner.train(train_files, epochs=10, callbacks=None)
```
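The training files above are plain text. Assuming they are line-aligned parallel corpora — one sentence per line, with line *i* of the source file paired with line *i* of the target file (an assumption about the format, not confirmed here) — toy training files could be created like this:

```python
# Hypothetical toy data; assumes line-aligned parallel text,
# one example per line in each file.
src_lines = ['hello world', 'how are you']
tgt_lines = ['bonjour le monde', 'comment allez-vous']

with open('/tmp/train.src.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(src_lines) + '\n')
with open('/tmp/train.tgt.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(tgt_lines) + '\n')

train_files = [('/tmp/train.src.txt', '/tmp/train.tgt.txt')]
```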
## BERT

Use your own data to pretrain a BERT model:
```python
from transformers_keras import BertTFRecordDatasetBuilder
from transformers_keras import BertRunner

dataset_builder = BertTFRecordDatasetBuilder()

model_config = {
    'num_layers': 6,
}

runner = BertRunner(model_config, dataset_builder, model_dir='/tmp/bert')

# Pretraining data must already be in TFRecord format
train_files = ['testdata/bert_custom_pretrain.tfrecord']
runner.train(train_files, epochs=10, callbacks=None)
```
Tips:

- You need to prepare your data in TFRecord format. You can use this script: create_pretraining_data.py
- You can subclass `transformers_keras.tokenizers.BertTFRecordDatasetBuilder` to parse custom TFRecord examples as you need.
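For reference, writing pretraining examples to a TFRecord file looks roughly like the sketch below. The feature names (`input_ids`, `segment_ids`) follow the common BERT convention; the exact schema `BertTFRecordDatasetBuilder` expects is an assumption here, so check the script mentioned above for the real one:

```python
import tensorflow as tf

def write_pretrain_tfrecord(path, examples):
    """Serialize a list of {feature_name: [ints]} dicts as tf.train.Examples."""
    with tf.io.TFRecordWriter(path) as writer:
        for ex in examples:
            features = {
                name: tf.train.Feature(int64_list=tf.train.Int64List(value=values))
                for name, values in ex.items()
            }
            record = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(record.SerializeToString())

# Hypothetical toy examples; real pretraining data carries more features
# (masked LM positions/labels, next-sentence labels, ...).
examples = [
    {'input_ids': [101, 2023, 102], 'segment_ids': [0, 0, 0]},
    {'input_ids': [101, 2003, 102], 'segment_ids': [0, 0, 0]},
]
write_pretrain_tfrecord('/tmp/toy_pretrain.tfrecord', examples)
record_count = sum(1 for _ in tf.data.TFRecordDataset('/tmp/toy_pretrain.tfrecord'))
```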
## ALBERT

You should process your data into TFRecord format. Modify the script `transformers_keras/utils/bert_tfrecord_custom_generator.py` as you need.
```python
from transformers_keras import BertTFRecordDatasetBuilder
from transformers_keras import AlbertRunner

# ALBERT uses the same data format as BERT
dataset_builder = BertTFRecordDatasetBuilder()

model_config = {
    'num_layers': 6,
    'num_groups': 1,
    'num_layers_each_group': 1,
}

runner = AlbertRunner(model_config, dataset_builder, model_dir='/tmp/albert')

train_files = ['testdata/bert_custom_pretrain.tfrecord']
runner.train(train_files, epochs=10, callbacks=None)
```
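The `num_groups` and `num_layers_each_group` keys reflect ALBERT's cross-layer parameter sharing: with `num_groups=1` and `num_layers_each_group=1`, all six layers reuse one set of weights, which is why ALBERT is far smaller than a BERT of the same depth. A toy NumPy illustration of the idea (not part of transformers-keras, and the group-to-layer mapping here is a simplification):

```python
import numpy as np

class SharedLayerStack:
    """Toy stack where num_layers forward steps reuse the weights of only
    num_groups * num_layers_each_group distinct layers (ALBERT-style)."""

    def __init__(self, num_layers, num_groups=1, num_layers_each_group=1,
                 d_model=8, seed=0):
        rng = np.random.default_rng(seed)
        n_unique = num_groups * num_layers_each_group
        self.num_layers = num_layers
        # Only n_unique weight matrices are allocated, regardless of depth.
        self.weights = [rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                        for _ in range(n_unique)]

    def __call__(self, x):
        for i in range(self.num_layers):
            # Map layer index i onto one of the shared weight matrices.
            w = self.weights[i * len(self.weights) // self.num_layers]
            x = np.tanh(x @ w)
        return x

stack = SharedLayerStack(num_layers=6)   # 6 layers, 1 shared weight matrix
y = stack(np.zeros((2, 8)))
```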