one more transformers lib
Simple way to use transformer models
Installation
pip install plain_transformers
Usage
Multimodal transformer example with two tokenizers:
Step one: import the model and some useful utilities:
import torch
from plain_transformers.models import MultimodalTransformer
from plain_transformers.layers import PostLNMultimodalTransformerDecoder
from plain_transformers.layers import PostLNTransformerEncoder
from plain_transformers import BPEWrapper
from plain_transformers.initializations import normal_initialization
from plain_transformers.samplers.nucleus_sampler import NucleusSampler
import youtokentome as yttm
Step two: train and load the tokenizers:
# train your encoder tokenizer
yttm.BPE.train(..., model='encoder_tokenizer.model')
# train your decoder tokenizer
yttm.BPE.train(..., model='decoder_tokenizer.model')
# load tokenizers
encoder_tokenizer = BPEWrapper(model='encoder_tokenizer.model')
decoder_tokenizer = BPEWrapper(model='decoder_tokenizer.model')
Step three: initialize the model configuration:
cfg = {
    'd_model': 768,
    'first_encoder': {
        'first_encoder_vocab_size': encoder_tokenizer.vocab_size(),
        'first_encoder_max_length': 512,
        'first_encoder_pad_token_id': encoder_tokenizer.pad_id,
        'first_encoder_token_type_vocab_size': 2,
        'first_encoder_n_heads': 8,
        'first_encoder_dim_feedforward': 2048,
        'first_encoder_num_layers': 3,
    },
    'second_encoder': {
        'second_encoder_vocab_size': encoder_tokenizer.vocab_size(),
        'second_encoder_max_length': 512,
        'second_encoder_pad_token_id': encoder_tokenizer.pad_id,
        'second_encoder_token_type_vocab_size': 2,
        'second_encoder_n_heads': 8,
        'second_encoder_dim_feedforward': 2048,
        'second_encoder_num_layers': 3,
    },
    'decoder': {
        'decoder_max_length': 512,
        'decoder_vocab_size': decoder_tokenizer.vocab_size(),
        'decoder_pad_token_id': decoder_tokenizer.pad_id,
        'decoder_token_type_vocab_size': 2,
        'decoder_n_heads': 8,
        'decoder_dim_feedforward': 2048,
        'decoder_num_layers': 3,
    },
}
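Because the three sub-dictionaries are unpacked into a single call with `**`, every key needs its section prefix so that no two sections collide. In standard multi-head attention, `d_model` must also be divisible by each head count. A minimal sketch of such a sanity check in plain Python follows; `check_config` is a hypothetical helper, not part of plain_transformers, and the trimmed `example_cfg` stands in for the full configuration:

```python
def check_config(cfg):
    """Flatten the nested config into keyword arguments, raising on obvious mistakes."""
    flat = {}
    for section in ("first_encoder", "second_encoder", "decoder"):
        prefix = section + "_"
        for key, value in cfg[section].items():
            # Distinct prefixes guarantee that keys from different
            # sections cannot collide once they are unpacked with **.
            if not key.startswith(prefix):
                raise ValueError(f"{key!r} lacks the {prefix!r} prefix")
            flat[key] = value
        n_heads = cfg[section][prefix + "n_heads"]
        # Assumption: heads split d_model evenly, as in standard attention.
        if cfg["d_model"] % n_heads != 0:
            raise ValueError(
                f"d_model={cfg['d_model']} is not divisible by {prefix}n_heads={n_heads}"
            )
    return flat


# Trimmed, illustrative configuration (values mirror the README example).
example_cfg = {
    "d_model": 768,
    "first_encoder": {"first_encoder_n_heads": 8, "first_encoder_num_layers": 3},
    "second_encoder": {"second_encoder_n_heads": 8, "second_encoder_num_layers": 3},
    "decoder": {"decoder_n_heads": 8, "decoder_num_layers": 3},
}

flat_kwargs = check_config(example_cfg)
print(sorted(flat_kwargs))
```

Running the check before constructing the model surfaces a misnamed key or an incompatible head count as a clear `ValueError` rather than as a shape error deep inside the model.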
Step four: initialize the model and apply weight initialization (default std=0.02):
model = MultimodalTransformer(
    PostLNTransformerEncoder,
    PostLNTransformerEncoder,
    PostLNMultimodalTransformerDecoder,
    cfg['d_model'],
    **cfg['first_encoder'],
    **cfg['second_encoder'],
    **cfg['decoder'],
    share_decoder_head_weights=True,
    share_encoder_decoder_embeddings=False,
    share_encoder_embeddings=True,
)
model.apply(normal_initialization)
Step five: train the model like an ordinary seq2seq model:
train(model, ...)
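Training "like ordinary seq2seq" usually means teacher forcing: the decoder receives the target sequence without its last position and is trained to predict the same sequence shifted by one token. A minimal sketch of preparing such a batch from token ids in plain Python follows; `make_teacher_forcing_batch`, the pad id `0`, and the sample ids are illustrative assumptions, and the library's own training code may differ:

```python
def make_teacher_forcing_batch(target_ids, pad_id=0):
    """Pad target sequences and split them into decoder inputs and labels.

    Each target sequence is expected to already include BOS and EOS ids.
    """
    max_len = max(len(seq) for seq in target_ids)
    decoder_inputs, labels = [], []
    for seq in target_ids:
        padded = seq + [pad_id] * (max_len - len(seq))
        decoder_inputs.append(padded[:-1])  # drop the last position
        labels.append(padded[1:])           # drop the first position
    return decoder_inputs, labels


# Hypothetical ids: 2 = BOS, 3 = EOS, 0 = PAD.
batch = [[2, 11, 12, 3], [2, 21, 3]]
decoder_inputs, labels = make_teacher_forcing_batch(batch)
print(decoder_inputs)  # [[2, 11, 12], [2, 21, 3]]
print(labels)          # [[11, 12, 3], [21, 3, 0]]
```

Positions whose label equals the pad id are typically excluded from the loss.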
Step six: initialize the sampler and generate the model's answer:
sampler = NucleusSampler(model, encoder_tokenizer=(encoder_tokenizer, encoder_tokenizer), decoder_tokenizer=decoder_tokenizer)
sampler.generate('Hello Bob, what are you doing?', second_input_text='Fine, thanks!', top_k=5)
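Nucleus (top-p) sampling keeps the smallest set of most probable tokens whose cumulative probability reaches a threshold `p`, renormalizes over that set, and samples from it. A minimal pure-Python sketch of the idea on a toy distribution follows; `nucleus_filter` and `sample_token` are hypothetical names, and this is not the `NucleusSampler` implementation:

```python
import random


def nucleus_filter(probs, top_p=0.9):
    """Return {token_id: renormalized probability} for the top-p nucleus."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    kept, cumulative = [], 0.0
    for token_id, prob in ranked:
        kept.append((token_id, prob))
        cumulative += prob
        if cumulative >= top_p:
            break  # smallest prefix whose mass reaches top_p
    total = sum(prob for _, prob in kept)
    return {token_id: prob / total for token_id, prob in kept}


def sample_token(probs, top_p=0.9, rng=random):
    """Draw one token id from the renormalized nucleus."""
    nucleus = nucleus_filter(probs, top_p)
    ids = list(nucleus)
    weights = [nucleus[i] for i in ids]
    return rng.choices(ids, weights=weights, k=1)[0]


# Toy next-token distribution over five token ids.
toy_probs = [0.5, 0.3, 0.1, 0.06, 0.04]
nucleus = nucleus_filter(toy_probs, top_p=0.75)
print(sorted(nucleus))  # [0, 1]
print(sample_token(toy_probs, top_p=0.75, rng=random.Random(0)))
```

With `top_p=0.75`, only the two most probable tokens survive, so low-probability tail tokens can never be sampled.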
Example
You can find a working example of NMT here.