LEE-Transformers


    Lean, elegant, explorative Transformers

    This package implements Transformers using standard PyTorch code. The Transformer models include enhancements and encourage best practices drawn from modern Transformer research and PyTorch development.

    Additionally, the Transformer models provide functionality to extract attention weights with minimal runtime impact.

    There are also some training utilities, such as a cosine annealing schedule with warm-up, one of the most widely used learning rate schedules for Transformers in natural language processing.
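
    As an illustration of what such a schedule does, here is a minimal, self-contained sketch of linear warm-up followed by cosine annealing, built on plain PyTorch (`torch.optim.lr_scheduler.LambdaLR`). It is not this package's own utility, whose exact API is not shown here.

    import math

    import torch


    def cosine_with_warmup(optimizer, warmup_steps, total_steps, min_factor=0.0):
        """Linear warm-up followed by cosine annealing down to `min_factor`."""

        def lr_factor(step):
            if step < warmup_steps:
                # Linear warm-up from (almost) zero up to the base learning rate.
                return (step + 1) / warmup_steps
            # Cosine annealing from the base learning rate down to `min_factor`.
            progress = min(
                1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
            return min_factor + (1 - min_factor) * 0.5 * (1 + math.cos(math.pi * progress))

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_factor)


    model = torch.nn.Linear(4, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = cosine_with_warmup(optimizer, warmup_steps=100, total_steps=1000)
    # Call `scheduler.step()` once after each optimizer step during training.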

    Installation

    python -m pip install git+<repository-uri>

    PyTorch 2.0

    While the code works with PyTorch 2.0 thanks to its backward-compatibility guarantees, a dedicated branch that adds explicit support for PyTorch 2.0 features is also available:

    python -m pip install git+<repository-uri>@pytorch-2.0

    Usage

    Encoder-decoder

    import torch
    
    from lee_transformers import LPETransformer
    
    
    batch_size = 2
    input_seq_len = 64
    output_seq_len = 64
    num_inputs = 4
    num_outputs = 8
    
    inputs = torch.randn((batch_size, input_seq_len, num_inputs))
    targets = torch.randn((batch_size, output_seq_len, num_outputs))
    
    encoder_decoder = LPETransformer(
        num_inputs,
        num_outputs,
        max_seq_len=input_seq_len,
        tgt_max_seq_len=output_seq_len,
    )
    # Automatically set up a causal attention mask for the decoder part by
    # specifying `tgt_mask=True`.
    outputs = encoder_decoder(inputs, targets, tgt_mask=True)
    
    assert outputs.shape == (batch_size, output_seq_len, num_outputs)
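
    For reference, the causal mask that `tgt_mask=True` sets up is the standard "subsequent" mask, which prevents each target position from attending to later positions. The sketch below is purely illustrative (built by hand, not via this package's API); the convention that `True` marks blocked positions follows `torch.nn.Transformer`.

    import torch

    seq_len = 4
    # Upper-triangular boolean mask; `True` marks positions that may not be
    # attended to.
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    print(causal_mask)
    # tensor([[False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True],
    #         [False, False, False, False]])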

    Encoder-only

    import torch
    
    from lee_transformers import LPETransformer
    
    
    batch_size = 2
    input_seq_len = 64
    num_inputs = 4
    num_outputs = 8
    
    inputs = torch.randn((batch_size, input_seq_len, num_inputs))
    
    encoder_only = LPETransformer(
        num_inputs,
        num_outputs,
        max_seq_len=input_seq_len,
        num_decoder_layers=0,
    )
    outputs = encoder_only(inputs)
    
    assert outputs.shape == (batch_size, input_seq_len, num_outputs)

    Decoder-only

    import torch
    
    from lee_transformers import LPETransformer
    
    
    batch_size = 2
    input_seq_len = 64
    num_inputs = 4
    num_outputs = 8
    
    inputs = torch.randn((batch_size, input_seq_len, num_inputs))
    
    decoder_only = LPETransformer(
        num_inputs,
        num_outputs,
        max_seq_len=input_seq_len,
        num_encoder_layers=0,
    )
    # Automatically set up a causal attention mask by specifying `src_mask=True`.
    outputs = decoder_only(inputs, src_mask=True)
    
    assert outputs.shape == (batch_size, input_seq_len, num_outputs)

    Attention weights (encoder-decoder)

    import torch
    
    from lee_transformers import LPETransformer
    
    
    batch_size = 2
    input_seq_len = 64
    output_seq_len = 64
    num_inputs = 4
    num_outputs = 8
    
    inputs = torch.randn((batch_size, input_seq_len, num_inputs))
    targets = torch.randn((batch_size, output_seq_len, num_outputs))
    
    encoder_decoder = LPETransformer(
        num_inputs,
        num_outputs,
        max_seq_len=input_seq_len,
        tgt_max_seq_len=output_seq_len,
    )
    output_tuple = encoder_decoder(
        inputs,
        targets,
        tgt_mask=True,
        return_encoder_self_attn=True,
        return_decoder_self_attn=True,
        return_decoder_cross_attn=True,
    )
    # Can also access usual outputs.
    assert output_tuple.outputs.shape == (batch_size, output_seq_len, num_outputs)
    
    assert output_tuple.encoder_self_attn.shape == (
        batch_size,
        encoder_decoder.encoder.num_layers,
        encoder_decoder.nhead,
        input_seq_len,
        input_seq_len,
    )
    assert output_tuple.decoder_self_attn.shape == (
        batch_size,
        encoder_decoder.decoder.num_layers,
        encoder_decoder.nhead,
        output_seq_len,
        output_seq_len,
    )
    assert output_tuple.decoder_cross_attn.shape == (
        batch_size,
        encoder_decoder.decoder.num_layers,
        encoder_decoder.nhead,
        output_seq_len,
        input_seq_len,
    )
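
    Since the returned attention weights are ordinary tensors with the layout shown above (batch, layer, head, query position, key position), they can be post-processed with standard tensor operations, for example averaging over heads to obtain per-layer attention maps:

    # Average the encoder self-attention over the head dimension and take the
    # first batch element, e.g. for plotting per-layer attention maps.
    encoder_attn_map = output_tuple.encoder_self_attn.mean(dim=2)[0]
    assert encoder_attn_map.shape == (
        encoder_decoder.encoder.num_layers,
        input_seq_len,
        input_seq_len,
    )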

    Generating Documentation

    python -m pip install sphinx
    cd docs
    make html  # Other backends also available.
    # Open `docs/_build/html/index.html`

    Running Tests

    python -m unittest