Lean, elegant, explorative Transformers
This package implements Transformers using standard PyTorch code. The models incorporate enhancements and encourage best practices from modern Transformer research and PyTorch development.
Additionally, the models can return attention weights with minimal runtime impact.
There are also training utilities such as a cosine annealing schedule with warm-up, one of the most widely used learning rate schedules for Transformers in natural language processing.
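The package's own scheduler interface is not shown here, but the warm-up-then-cosine shape can be sketched with plain PyTorch; the helper below is a hypothetical illustration, not the package's actual API.

import math
import torch

def cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    # Hypothetical helper: linearly warm up to the base learning rate,
    # then decay it to zero along a cosine curve.
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)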
Installation
python -m pip install git+<repository-uri>
PyTorch 2.0
While the code runs on PyTorch 2.0 thanks to PyTorch's backward-compatibility guarantees, a separate version that adds explicit support for PyTorch 2.0 features is available.
python -m pip install git+<repository-uri>@pytorch-2.0
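For example, the models can usually be compiled like any other torch.nn.Module; the snippet below only assumes the standard torch.compile entry point and is not a description of what the pytorch-2.0 branch changes.

import torch
from lee_transformers import LPETransformer

model = LPETransformer(4, 8, max_seq_len=64, num_decoder_layers=0)
# Standard PyTorch 2.0 compilation; whether the pytorch-2.0 branch relies
# on `torch.compile` internally is an assumption of this sketch.
compiled_model = torch.compile(model)
outputs = compiled_model(torch.randn((2, 64, 4)))
assert outputs.shape == (2, 64, 8)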
Usage
Encoder-decoder
import torch
from lee_transformers import LPETransformer
batch_size = 2
input_seq_len = 64
output_seq_len = 64
num_inputs = 4
num_outputs = 8
inputs = torch.randn((batch_size, input_seq_len, num_inputs))
targets = torch.randn((batch_size, output_seq_len, num_outputs))
encoder_decoder = LPETransformer(
    num_inputs,
    num_outputs,
    max_seq_len=input_seq_len,
    tgt_max_seq_len=output_seq_len,
)
# Automatically set up a causal attention mask for the decoder part by
# specifying `tgt_mask=True`.
outputs = encoder_decoder(inputs, targets, tgt_mask=True)
assert outputs.shape == (batch_size, output_seq_len, num_outputs)
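Continuing the example above, the outputs plug into an ordinary training step; the mean-squared-error objective below is only an illustration for continuous targets, not a recommendation baked into the package.

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(encoder_decoder.parameters(), lr=1e-4)
loss = loss_fn(outputs, targets)  # Regression loss against the target sequence (illustrative only).
optimizer.zero_grad()
loss.backward()
optimizer.step()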
Encoder-only
import torch
from lee_transformers import LPETransformer
batch_size = 2
input_seq_len = 64
num_inputs = 4
num_outputs = 8
inputs = torch.randn((batch_size, input_seq_len, num_inputs))
encoder_only = LPETransformer(
    num_inputs,
    num_outputs,
    max_seq_len=input_seq_len,
    num_decoder_layers=0,
)
outputs = encoder_only(inputs)
assert outputs.shape == (batch_size, input_seq_len, num_outputs)
Decoder-only
import torch
from lee_transformers import LPETransformer
batch_size = 2
input_seq_len = 64
num_inputs = 4
num_outputs = 8
inputs = torch.randn((batch_size, input_seq_len, num_inputs))
decoder_only = LPETransformer(
    num_inputs,
    num_outputs,
    max_seq_len=input_seq_len,
    num_encoder_layers=0,
)
# Automatically set up a causal attention mask by specifying `src_mask=True`.
outputs = decoder_only(inputs, src_mask=True)
assert outputs.shape == (batch_size, input_seq_len, num_outputs)
Attention weights (encoder-decoder)
import torch
from lee_transformers import LPETransformer
batch_size = 2
input_seq_len = 64
output_seq_len = 64
num_inputs = 4
num_outputs = 8
inputs = torch.randn((batch_size, input_seq_len, num_inputs))
targets = torch.randn((batch_size, output_seq_len, num_outputs))
encoder_decoder = LPETransformer(
    num_inputs,
    num_outputs,
    max_seq_len=input_seq_len,
    tgt_max_seq_len=output_seq_len,
)
output_tuple = encoder_decoder(
    inputs,
    targets,
    tgt_mask=True,
    return_encoder_self_attn=True,
    return_decoder_self_attn=True,
    return_decoder_cross_attn=True,
)
# The usual outputs can still be accessed.
assert output_tuple.outputs.shape == (batch_size, output_seq_len, num_outputs)
assert output_tuple.encoder_self_attn.shape == (
    batch_size,
    encoder_decoder.encoder.num_layers,
    encoder_decoder.nhead,
    input_seq_len,
    input_seq_len,
)
assert output_tuple.decoder_self_attn.shape == (
    batch_size,
    encoder_decoder.decoder.num_layers,
    encoder_decoder.nhead,
    output_seq_len,
    output_seq_len,
)
assert output_tuple.decoder_cross_attn.shape == (
    batch_size,
    encoder_decoder.decoder.num_layers,
    encoder_decoder.nhead,
    output_seq_len,
    input_seq_len,
)
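The returned weights are plain tensors, so they can be reduced or plotted directly; averaging over layers and heads, as below, is just one convenient summary and not something the package prescribes.

# Average the cross-attention over layers and heads to get one
# (output position, input position) map per batch element.
cross_attn_map = output_tuple.decoder_cross_attn.mean(dim=(1, 2))
assert cross_attn_map.shape == (batch_size, output_seq_len, input_seq_len)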
Generate Documentation
python -m pip install sphinx
cd docs
make html  # Other builders are also available.
# Open `docs/_build/html/index.html`
Running Tests
python -m unittest