import math
from typing import cast, Optional
import torch
import torch as th
from .. import common
def get_sinusoidal_encoding(
d_model: int,
start_id: int,
end_id: int,
max_wavelength_base: float = 10000.0,
) -> torch.Tensor:
"""Return a sinusoidal position encoding for inputs of shape
`(seq_len, batch_size, embedding_dim)`.
"""
# `end_id` is exclusive
position = th.arange(start_id, end_id).unsqueeze(1)
div_term = th.exp(
th.arange(0, d_model, 2)
* (-math.log(max_wavelength_base) / d_model)
)
pe = th.empty(len(position), 1, d_model)
    # If `start_id` is odd, the sine/cosine assignment to the even/odd
    # embedding dimensions is swapped. If `d_model` is odd, the dimension
    # half that is one element shorter drops the last frequency.
    starts_odd = int(start_id % 2 != 0)
    d_model_odd = int(d_model % 2 != 0)
    cos_start = 1 - starts_odd
    pe[:, 0, starts_odd::2] = th.sin(
        position * div_term[:len(div_term) - (d_model_odd and starts_odd)])
    pe[:, 0, cos_start::2] = th.cos(
        position * div_term[:len(div_term) - (d_model_odd and cos_start)])
return pe
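
# Example (illustrative): for an even `start_id`, the even embedding
# dimensions hold sines and the odd dimensions hold cosines.
#
#   pe = get_sinusoidal_encoding(d_model=8, start_id=0, end_id=16)
#   pe.shape          # torch.Size([16, 1, 8])
#   pe[3, 0, 0::2]    # sines of position 3 at the 4 frequencies
#   pe[3, 0, 1::2]    # cosines of position 3 at the same frequencies
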
def prepare_pos_ids(
x: torch.Tensor,
pos_ids: Optional[torch.Tensor],
max_len: int,
batch_first: bool,
) -> torch.Tensor:
    """Derive position IDs for `x` when none are given and check that the
    given or derived IDs lie in `[0, max_len)` and do not decrease.
    """
if pos_ids is None:
seq_len = common.get_seq_len(x, batch_first)
assert seq_len <= max_len, (
'input sequence length exceeds cached positional encoding max '
'length.'
)
pos_ids = th.arange(seq_len, device=x.device)
pos_ids = common.reshape_pos_ids(x, pos_ids, batch_first)
if batch_first:
start_ids = pos_ids[..., 0]
end_ids = pos_ids[..., -1]
else:
start_ids = pos_ids[0]
end_ids = pos_ids[-1]
    assert (start_ids <= end_ids).all(), \
        'position IDs must not decrease from the first to the last element'
    assert (
        (start_ids >= 0).all()
        and (end_ids < max_len).all()
    ), (
        'given `pos_ids` fall outside the cached positional encodings '
        '(must satisfy 0 <= `pos_ids` < `max_len`).'
    )
return pos_ids
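
# Example (illustrative; assumes `common.get_seq_len` and
# `common.reshape_pos_ids` behave as their names suggest):
#
#   x = th.zeros(2, 4, 8)   # (N, S, E)
#   ids = prepare_pos_ids(x, None, max_len=16, batch_first=True)
#   # → position IDs 0..3, laid out to match the batch-first input;
#   #   explicitly passed IDs are checked to lie in [0, max_len) and to
#   #   not decrease.
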
class PositionalEncoding(th.nn.Module):
def __init__(
self,
d_model: int,
dropout: float = 0.1,
max_len: int = 4096,
max_wavelength_base: float = 10000.0,
batch_first: bool = True,
) -> None:
super().__init__()
self.dropout = th.nn.Dropout(p=dropout)
self.max_len = max_len
self.batch_first = batch_first
pe = get_sinusoidal_encoding(d_model, 0, max_len, max_wavelength_base)
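        # `Embedding.from_pretrained` freezes its weight by default, so the
        # cached encoding table is a fixed (non-trainable) lookup.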
self.pe = th.nn.Embedding.from_pretrained(pe.squeeze(1))
def forward(
self,
x: torch.Tensor,
pos_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Positionally encode the given input.
Args:
x: Sequence to positionally encode.
pos_ids: Position IDs for the sequence elements in `x`.
Assumed to be monotonically increasing.
Shape:
- x: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
- pos_ids: `(S,)` for unbatched input or to broadcast, `(N,
S)` if `batch_first=True` or `(S, N)` if
`batch_first=False`.
- output: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
where S is the sequence length, N is the batch size, E is
the number of features.
"""
pos_ids = prepare_pos_ids(x, pos_ids, self.max_len, self.batch_first)
x = x + self.pe(pos_ids)
return self.dropout(x)
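
# Example (illustrative; assumes the `common` helpers resolve sequence
# length and ID layout as their names suggest): the additive encoding is
# applied to token embeddings before the first transformer layer.
#
#   pos_enc = PositionalEncoding(d_model=64, dropout=0.0, max_len=128)
#   x = th.zeros(2, 16, 64)   # (N, S, E) with batch_first=True
#   y = pos_enc(x)            # same shape; row s carries the input plus
#                             # the sinusoidal encoding of position s
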
class RotaryPositionalEncoding(th.nn.Module):
"""Rotary positional encoding as introduced by
https://arxiv.org/abs/2002.04745.
`odd_even_split=False` selects an alternative formulation of RoPE
for faster throughput, as in GPT-NeoX
(https://github.com/EleutherAI/gpt-neox/blob/43ea51c2f3aeef2fc642ba401ce08844eb5a0240/megatron/model/positional_embeddings.py#L38-L92).
    The two variants are not elementwise equivalent, because the tensor is
    split into real/imaginary components differently (interleaved pairs
    versus first/second half).
"""
def __init__(
self,
d_model: int,
dropout: float = 0.1,
max_len: int = 4096,
max_wavelength_base: float = 10000.0,
odd_even_split: bool = True,
batch_first: bool = True,
) -> None:
super().__init__()
assert d_model % 2 == 0, \
'`d_model` must be divisible by two for rotary position encoding'
self.dropout = th.nn.Dropout(p=dropout)
self.max_len = max_len
self.odd_even_split = odd_even_split
self.batch_first = batch_first
pe = get_sinusoidal_encoding(d_model, 0, max_len, max_wavelength_base)
pe = pe.squeeze(1)
sin_pe = pe[..., ::2]
sin_pe = self._prepare_sparse_sinusoid(sin_pe)
cos_pe = pe[..., 1::2]
cos_pe = self._prepare_sparse_sinusoid(cos_pe)
self.sin_pe = th.nn.Embedding.from_pretrained(sin_pe)
self.cos_pe = th.nn.Embedding.from_pretrained(cos_pe)
def _prepare_sparse_sinusoid(
self,
sinusoid_pe: torch.Tensor,
) -> torch.Tensor:
if self.odd_even_split:
# Repeat each embedding dimension:
# [[0, 1], [2, 3]] → [[0, 0, 1, 1], [2, 2, 3, 3]]
sinusoid_pe = sinusoid_pe.repeat_interleave(2, dim=-1)
else:
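            # Tile the whole frequency block instead:
            # [[0, 1], [2, 3]] → [[0, 1, 0, 1], [2, 3, 2, 3]]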
sinusoid_pe = sinusoid_pe.repeat(
(1,) * (sinusoid_pe.dim() - 1) + (2,),
)
return sinusoid_pe
    def _prepare_sparse_input(self, x: torch.Tensor) -> torch.Tensor:
        if self.odd_even_split:
            # Rotate interleaved pairs:
            # (x0, x1, x2, x3, ...) → (-x1, x0, -x3, x2, ...)
            x_shifted = x[..., ::2]
            neg_x_shifted = -x[..., 1::2]
            all_x_shifted = th.stack([neg_x_shifted, x_shifted], dim=-1)
        else:
            # GPT-NeoX-style "rotate half": negate the second half and
            # swap the halves: (x_first, x_last) → (-x_last, x_first)
            x_first, x_last = th.chunk(x, 2, dim=-1)
            all_x_shifted = th.cat([-x_last, x_first], dim=-1)
        return all_x_shifted.reshape_as(x)
def forward(
self,
x: torch.Tensor,
pos_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Positionally encode the given input.
Args:
x: Sequence to positionally encode. Assumed to be a
projected query/key.
pos_ids: Position IDs for the sequence elements in `x`.
Assumed to be monotonically increasing.
Shape:
- x: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
- pos_ids: `(S,)` for unbatched input or to broadcast, `(N,
S)` if `batch_first=True` or `(S, N)` if
`batch_first=False`.
- output: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
where S is the sequence length, N is the batch size, E is
the number of features.
"""
pos_ids = prepare_pos_ids(x, pos_ids, self.max_len, self.batch_first)
cos_pe = self.cos_pe(pos_ids)
sin_pe = self.sin_pe(pos_ids)
x_shifted = self._prepare_sparse_input(x)
return x * cos_pe + x_shifted * sin_pe
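
# Example (illustrative): RoPE is applied to the projected queries and keys
# (not the values) before the attention-score computation, so that relative
# positions enter the query-key dot products.
#
#   rope = RotaryPositionalEncoding(d_model=64, dropout=0.0)
#   q = th.randn(2, 16, 64)   # (N, S, E), batch_first=True
#   k = th.randn(2, 16, 64)
#   q, k = rope(q), rope(k)   # shapes unchanged
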
class XPosEncoding(RotaryPositionalEncoding):
"""xPos positional encoding as introduced by
https://arxiv.org/abs/2212.10554.
    Similar to the official implementation at
    https://github.com/sunyt32/torchscale, this slightly deviates from the
    paper's formulation by dividing the position-dependent exponents of the
    dampening factor ξ by an additional scale, `pos_scale_base`.
    Likewise, the exponents are by default centered around 0 rather than
    `seq_len / 2`.
"""
def __init__(
self,
d_model: int,
dropout: float = 0.1,
max_len: int = 4096,
max_wavelength_base: float = 10000.0,
pos_scale_base: float = 512.0,
recenter_pos: bool = True,
gamma: float = 0.4,
batch_first: bool = True,
) -> None:
super().__init__(
d_model,
dropout=dropout,
max_len=max_len,
max_wavelength_base=max_wavelength_base,
batch_first=batch_first,
)
self.pos_scale_base = pos_scale_base
self.recenter_pos = recenter_pos
        # Per-dimension dampening factor, increasing from γ / (1 + γ)
        # towards 1 with the embedding dimension; repeated so that paired
        # (interleaved) dimensions share a factor.
        dampening = (
            (th.arange(0, d_model, 2) + gamma * d_model)
            / ((1.0 + gamma) * d_model)
        ).repeat_interleave(2)
        self.register_buffer('dampening', dampening)
def forward(
self,
x: torch.Tensor,
pos_ids: Optional[torch.Tensor] = None,
downscale: bool = False,
) -> torch.Tensor:
"""Positionally encode the given input.
Args:
x: Sequence to positionally encode. Assumed to be a
projected query/key.
downscale: Whether to invert the dampening factor.
Shape:
- x: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
- output: `(S, E)` for unbatched input, `(N, S, E)` if
`batch_first=True` or `(S, N, E)` if `batch_first=False`.
where S is the sequence length, N is the batch size, E is
the number of features.
"""
pos_ids = prepare_pos_ids(x, pos_ids, self.max_len, self.batch_first)
cos_pe = self.cos_pe(pos_ids)
sin_pe = self.sin_pe(pos_ids)
dampening = cast(th.Tensor, self.dampening)
dampening = dampening.reshape((1,) * (x.dim() - 1) + (-1,))
x_shifted = self._prepare_sparse_input(x)
pos_pows = pos_ids.to(x.dtype)
if self.batch_first:
end_ids = pos_ids[..., -1]
else:
end_ids = pos_ids[-1]
end_ids = end_ids.to(x.dtype)
if self.recenter_pos:
# We deliberately do not use `start_ids` here to _always_
# center around 0; otherwise `pos_ids` with `start_ids` > 0
# become meaningless.
pos_pows = pos_pows - end_ids / 2
pos_pows = pos_pows / self.pos_scale_base
if downscale:
pos_pows = -pos_pows
if self.batch_first:
pos_pows = pos_pows.reshape((1,) * (x.dim() - 2) + (-1, 1))
else:
pos_pows = pos_pows.reshape((-1,) + (1,) * (x.dim() - 1))
if x.dim() < 3:
dampening = dampening.squeeze(0)
if self.batch_first:
pos_pows = pos_pows.squeeze(0)
else:
pos_pows = pos_pows.squeeze(-2)
dampening = dampening**pos_pows
return (x * cos_pe + x_shifted * sin_pe) * dampening
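
# Example (illustrative): in the xPos formulation, queries are scaled by the
# dampening factor and keys by its inverse, so the two factors combine into
# a term that depends only on the relative query-key offset.
#
#   xpos = XPosEncoding(d_model=64, dropout=0.0)
#   q = th.randn(2, 16, 64)        # (N, S, E), batch_first=True
#   k = th.randn(2, 16, 64)
#   q = xpos(q)                    # scaled by dampening ** pos_pows
#   k = xpos(k, downscale=True)    # scaled by dampening ** -pos_pows
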
class DownscalingXPosEncoding(th.nn.Module):
    """Apply a wrapped `XPosEncoding` with the dampening factor inverted
    (`downscale=True` by default).
    """
def __init__(
self,
xpos: XPosEncoding,
) -> None:
super().__init__()
self.xpos = xpos
def forward(
self,
x: torch.Tensor,
pos_ids: Optional[torch.Tensor] = None,
downscale: bool = True,
) -> torch.Tensor:
return self.xpos(x, pos_ids, downscale=downscale)
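
# Example (illustrative): `DownscalingXPosEncoding` wraps a shared
# `XPosEncoding` so the inverted scaling can be exposed as its own module,
# e.g. when an attention block only ever calls `module(x)`:
#
#   xpos = XPosEncoding(d_model=64, dropout=0.0)
#   q_encoder = xpos                            # dampening ** pos_pows
#   k_encoder = DownscalingXPosEncoding(xpos)   # dampening ** -pos_pows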