positional_encodings.py
    import math
    from typing import cast, Optional
    
    import torch
    import torch as th
    
    from .. import common
    
    
    def get_sinusoidal_encoding(
            d_model: int,
            start_id: int,
            end_id: int,
            max_wavelength_base: float = 10000.0,
    ) -> torch.Tensor:
        """Return a sinusoidal position encoding for inputs of shape
        `(seq_len, batch_size, embedding_dim)`.
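
        For illustration (a minimal sketch, not part of the original
        docs), a shape check with a small model dimension::

            pe = get_sinusoidal_encoding(4, start_id=0, end_id=8)
            assert pe.shape == (8, 1, 4)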
        """
        # `end_id` is exclusive
        position = th.arange(start_id, end_id).unsqueeze(1)
        div_term = th.exp(
            th.arange(0, d_model, 2)
            * (-math.log(max_wavelength_base) / d_model)
        )
        pe = th.empty(len(position), 1, d_model)
        # If `d_model` is odd, the even- and odd-indexed slices differ in
        # length by one, so the frequency vector is trimmed to match the
        # shorter slice; if `start_id` is odd, the sine/cosine slices are
        # swapped.
        starts_odd = start_id % 2 != 0
        is_odd = d_model % 2 != 0
        pe[:, 0, starts_odd::2] = th.sin(
            position * div_term[:len(div_term) - (is_odd and starts_odd)])
        pe[:, 0, not starts_odd::2] = th.cos(
            position * div_term[:len(div_term) - (is_odd and not starts_odd)])
        return pe
    
    
    def prepare_pos_ids(
            x: torch.Tensor,
            pos_ids: Optional[torch.Tensor],
            max_len: int,
            batch_first: bool,
    ) -> torch.Tensor:
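        """Validate and normalize `pos_ids` for `x`, defaulting to
        `th.arange(seq_len)` when none are given; positions must be
        monotonically increasing and lie in `[0, max_len)`.
        """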
        if pos_ids is None:
            seq_len = common.get_seq_len(x, batch_first)
            assert seq_len <= max_len, (
                'input sequence length exceeds cached positional encoding max '
                'length.'
            )
            pos_ids = th.arange(seq_len, device=x.device)
    
        pos_ids = common.reshape_pos_ids(x, pos_ids, batch_first)
    
        if batch_first:
            start_ids = pos_ids[..., 0]
            end_ids = pos_ids[..., -1]
        else:
            start_ids = pos_ids[0]
            end_ids = pos_ids[-1]
    
        assert (start_ids <= end_ids).all(), \
            'position IDs are not monotonically increasing'
        assert (
            (start_ids >= 0).all()
            and (end_ids < max_len).all()
        ), (
            'given `pos_ids` exceed cached positional IDs '
            '(must be 0 <= `pos_ids` < `max_len`).'
        )
        return pos_ids
    
    
    class PositionalEncoding(th.nn.Module):
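        """The standard additive sinusoidal positional encoding (as in
        "Attention Is All You Need"), implemented as a frozen embedding
        lookup followed by dropout.
        """
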
        def __init__(
                self,
                d_model: int,
                dropout: float = 0.1,
                max_len: int = 4096,
                max_wavelength_base: float = 10000.0,
                batch_first: bool = True,
        ) -> None:
            super().__init__()
            self.dropout = th.nn.Dropout(p=dropout)
            self.max_len = max_len
            self.batch_first = batch_first
    
            pe = get_sinusoidal_encoding(d_model, 0, max_len, max_wavelength_base)
            self.pe = th.nn.Embedding.from_pretrained(pe.squeeze(1))
    
        def forward(
                self,
                x: torch.Tensor,
                pos_ids: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """Positionally encode the given input.
    
            Args:
                x: Sequence to positionally encode.
                pos_ids: Position IDs for the sequence elements in `x`.
                    Assumed to be monotonically increasing.
    
            Shape:
                - x: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.
                - pos_ids: `(S,)` for unbatched input or to broadcast, `(N,
                  S)` if `batch_first=True` or `(S, N)` if
                  `batch_first=False`.
    
                - output: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.
    
                where S is the sequence length, N is the batch size, E is
                the number of features.
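
            Example (a minimal sketch, shapes only; relies on the default
            `pos_ids`, i.e. `arange` over the sequence)::

                enc = PositionalEncoding(d_model=16, dropout=0.0, max_len=32)
                x = th.zeros(2, 5, 16)    # (N, S, E) with batch_first=True
                out = enc(x)              # same shape, sinusoids added to `x`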
            """
            pos_ids = prepare_pos_ids(x, pos_ids, self.max_len, self.batch_first)
            x = x + self.pe(pos_ids)
            return self.dropout(x)
    
    
    class RotaryPositionalEncoding(th.nn.Module):
        """Rotary positional encoding as introduced by
        https://arxiv.org/abs/2104.09864 (RoFormer).
    
        `odd_even_split=False` selects an alternative formulation of RoPE
        for faster throughput, as in GPT-NeoX
        (https://github.com/EleutherAI/gpt-neox/blob/43ea51c2f3aeef2fc642ba401ce08844eb5a0240/megatron/model/positional_embeddings.py#L38-L92).
        Results are not mathematically equivalent, because the tensor is
        split into real/imaginary parts differently.
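
        For illustration (a sketch using the private
        `_prepare_sparse_input` helper), the "rotated" companion tensor
        differs between the two modes::

            x = th.tensor([0., 1., 2., 3.])
            RotaryPositionalEncoding(4, max_len=8)._prepare_sparse_input(x)
            # -> [-1., 0., -3., 2.]   (interleaved odd/even pairs)
            RotaryPositionalEncoding(
                4, max_len=8, odd_even_split=False)._prepare_sparse_input(x)
            # -> [-2., -3., 0., 1.]   (rotate-half split)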
        """
    
        def __init__(
                self,
                d_model: int,
                dropout: float = 0.1,
                max_len: int = 4096,
                max_wavelength_base: float = 10000.0,
                odd_even_split: bool = True,
                batch_first: bool = True,
        ) -> None:
            super().__init__()
    
            assert d_model % 2 == 0, \
                '`d_model` must be divisible by two for rotary position encoding'
    
            self.dropout = th.nn.Dropout(p=dropout)
            self.max_len = max_len
            self.odd_even_split = odd_even_split
            self.batch_first = batch_first
    
            pe = get_sinusoidal_encoding(d_model, 0, max_len, max_wavelength_base)
            pe = pe.squeeze(1)
            sin_pe = pe[..., ::2]
            sin_pe = self._prepare_sparse_sinusoid(sin_pe)
    
            cos_pe = pe[..., 1::2]
            cos_pe = self._prepare_sparse_sinusoid(cos_pe)
    
            self.sin_pe = th.nn.Embedding.from_pretrained(sin_pe)
            self.cos_pe = th.nn.Embedding.from_pretrained(cos_pe)
    
        def _prepare_sparse_sinusoid(
                self,
                sinusoid_pe: torch.Tensor,
        ) -> torch.Tensor:
            if self.odd_even_split:
                # Repeat each embedding dimension:
                # [[0, 1], [2, 3]] → [[0, 0, 1, 1], [2, 2, 3, 3]]
                sinusoid_pe = sinusoid_pe.repeat_interleave(2, dim=-1)
            else:
                sinusoid_pe = sinusoid_pe.repeat(
                    (1,) * (sinusoid_pe.dim() - 1) + (2,),
                )
            return sinusoid_pe
    
        def _prepare_sparse_input(self, x: torch.Tensor) -> torch.Tensor:
            if self.odd_even_split:
                x_shifted = x[..., ::2]
                neg_x_shifted = -x[..., 1::2]
                all_x_shifted = th.stack([neg_x_shifted, x_shifted], dim=-1)
            else:
                x_first, x_last = th.chunk(x, 2, dim=-1)
                all_x_shifted = th.cat([-x_last, x_first], dim=-1)
            return all_x_shifted.reshape_as(x)
    
        def forward(
                self,
                x: torch.Tensor,
                pos_ids: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            """Positionally encode the given input.
    
            Args:
                x: Sequence to positionally encode. Assumed to be a
                    projected query/key.
                pos_ids: Position IDs for the sequence elements in `x`.
                    Assumed to be monotonically increasing.
    
            Shape:
                - x: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.
                - pos_ids: `(S,)` for unbatched input or to broadcast, `(N,
                  S)` if `batch_first=True` or `(S, N)` if
                  `batch_first=False`.
    
                - output: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.
    
                where S is the sequence length, N is the batch size, E is
                the number of features.
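
            Example (a minimal sketch, shapes only)::

                rope = RotaryPositionalEncoding(d_model=8, max_len=16)
                q = th.randn(2, 5, 8)    # projected queries, (N, S, E)
                q_rot = rope(q)          # same shape, rotated pairwise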
            """
            pos_ids = prepare_pos_ids(x, pos_ids, self.max_len, self.batch_first)
            cos_pe = self.cos_pe(pos_ids)
            sin_pe = self.sin_pe(pos_ids)
    
            x_shifted = self._prepare_sparse_input(x)
            return x * cos_pe + x_shifted * sin_pe
    
    
    class XPosEncoding(RotaryPositionalEncoding):
        """xPos positional encoding as introduced by
        https://arxiv.org/abs/2212.10554.
    
        Similar to the official implementation at
        https://github.com/sunyt32/torchscale, this slightly deviates from
        the paper formulation by scaling the position-related powers applied
        to the dampening factor ξ by an additional factor.
        Similarly, the position-related powers are centered around 0 instead
        of `seq_len / 2` by default.
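
        Concretely, embedding dimensions `2i` and `2i + 1` share the
        dampening base `(2i + gamma * d_model) / ((1 + gamma) * d_model)`,
        which is raised to the (scaled, optionally recentered) position
        power in `forward`.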
        """
    
        def __init__(
                self,
                d_model: int,
                dropout: float = 0.1,
                max_len: int = 4096,
                max_wavelength_base: float = 10000.0,
                pos_scale_base: float = 512.0,
                recenter_pos: bool = True,
                gamma: float = 0.4,
                batch_first: bool = True,
        ) -> None:
            super().__init__(
                d_model,
                dropout=dropout,
                max_len=max_len,
                max_wavelength_base=max_wavelength_base,
                batch_first=batch_first,
            )
            self.pos_scale_base = pos_scale_base
            self.recenter_pos = recenter_pos
    
            dampening = (
                (torch.arange(0, d_model, 2) + gamma * d_model)
                / ((1.0 + gamma) * d_model)
            ).repeat_interleave(2)
            self.register_buffer('dampening', dampening)
    
        def forward(
                self,
                x: torch.Tensor,
                pos_ids: Optional[torch.Tensor] = None,
                downscale: bool = False,
        ) -> torch.Tensor:
            """Positionally encode the given input.
    
            Args:
                x: Sequence to positionally encode. Assumed to be a
                    projected query/key.
                pos_ids: Position IDs for the sequence elements in `x`.
                    Assumed to be monotonically increasing.
                downscale: Whether to invert the dampening factor.
    
            Shape:
                - x: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.
                - pos_ids: `(S,)` for unbatched input or to broadcast, `(N,
                  S)` if `batch_first=True` or `(S, N)` if
                  `batch_first=False`.
    
                - output: `(S, E)` for unbatched input, `(N, S, E)` if
                  `batch_first=True` or `(S, N, E)` if `batch_first=False`.
    
                where S is the sequence length, N is the batch size, E is
                the number of features.
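
            Example (a sketch of the usual query/key pairing, shapes
            only)::

                xpos = XPosEncoding(d_model=8, max_len=16)
                q = xpos(th.randn(2, 5, 8))                  # upscaled
                k = xpos(th.randn(2, 5, 8), downscale=True)  # downscaled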
            """
            pos_ids = prepare_pos_ids(x, pos_ids, self.max_len, self.batch_first)
            cos_pe = self.cos_pe(pos_ids)
            sin_pe = self.sin_pe(pos_ids)
            dampening = cast(th.Tensor, self.dampening)
            dampening = dampening.reshape((1,) * (x.dim() - 1) + (-1,))
            x_shifted = self._prepare_sparse_input(x)
    
            pos_pows = pos_ids.to(x.dtype)
            if self.batch_first:
                end_ids = pos_ids[..., -1]
            else:
                end_ids = pos_ids[-1]
            end_ids = end_ids.to(x.dtype)
            if self.recenter_pos:
                # We deliberately do not use `start_ids` here to _always_
                # center around 0; otherwise `pos_ids` with `start_ids` > 0
                # become meaningless.
                pos_pows = pos_pows - end_ids / 2
            pos_pows = pos_pows / self.pos_scale_base
            if downscale:
                pos_pows = -pos_pows
    
            if self.batch_first:
                pos_pows = pos_pows.reshape((1,) * (x.dim() - 2) + (-1, 1))
            else:
                pos_pows = pos_pows.reshape((-1,) + (1,) * (x.dim() - 1))
    
            if x.dim() < 3:
                dampening = dampening.squeeze(0)
                if self.batch_first:
                    pos_pows = pos_pows.squeeze(0)
                else:
                    pos_pows = pos_pows.squeeze(-2)
    
            dampening = dampening**pos_pows
    
            return (x * cos_pe + x_shifted * sin_pe) * dampening
    
    
    class DownscalingXPosEncoding(th.nn.Module):
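        """Thin wrapper around an `XPosEncoding` that applies it with
        `downscale=True` by default, i.e. with the dampening factor
        inverted.
        """
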
        def __init__(
                self,
                xpos: XPosEncoding,
        ) -> None:
            super().__init__()
            self.xpos = xpos
    
        def forward(
                self,
                x: torch.Tensor,
                pos_ids: Optional[torch.Tensor] = None,
                downscale: bool = True,
        ) -> torch.Tensor:
            return self.xpos(x, pos_ids, downscale=downscale)
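

    if __name__ == '__main__':
        # Minimal smoke test (illustrative only, assuming the `common`
        # helpers broadcast position IDs as documented above): the usual
        # xPos pairing keeps query/key shapes intact.
        xpos = XPosEncoding(d_model=8, max_len=16)
        key_xpos = DownscalingXPosEncoding(xpos)
        q = xpos(th.randn(2, 5, 8))
        k = key_xpos(th.randn(2, 5, 8))
        assert q.shape == k.shape == (2, 5, 8)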