diff --git a/bp_transformers/layers/positional_encodings.py b/bp_transformers/layers/positional_encodings.py
index a59d6941dc0a6f5585d39ab80376873c8eea0c2c..57ac7db826807f6b44c524e9ab589758726c623c 100644
--- a/bp_transformers/layers/positional_encodings.py
+++ b/bp_transformers/layers/positional_encodings.py
@@ -264,10 +264,17 @@ class XPosEncoding(RotaryPositionalEncoding):
         dampening = dampening.reshape((1,) * (x.dim() - 1) + (-1,))
 
         x_shifted = self._prepare_sparse_input(x)
-        seq_len = common.get_seq_len(x, self.batch_first)
-        pos_pows = torch.arange(0, seq_len, device=x.device, dtype=x.dtype)
+        pos_pows = pos_ids.to(x.dtype)
+        if self.batch_first:
+            end_ids = pos_ids[..., -1]
+        else:
+            end_ids = pos_ids[-1]
+        end_ids = end_ids.to(x.dtype)
         if self.recenter_pos:
-            pos_pows = pos_pows - seq_len / 2
+            # We deliberately do not use `start_ids` here to _always_
+            # center around 0; otherwise `pos_ids` with `start_ids` > 0
+            # become meaningless.
+            pos_pows = pos_pows - end_ids / 2
         pos_pows = pos_pows / self.pos_scale_base
         if downscale:
             pos_pows = -pos_pows
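For context, a minimal standalone sketch of the changed logic: position powers are now derived from an explicit `pos_ids` tensor rather than `torch.arange(seq_len)`, so callers can pass offset positions. The function name, default `pos_scale_base`, and 1-D shape below are illustrative assumptions, not the library's API.

```python
import torch


def pos_pows_from_ids(pos_ids: torch.Tensor,
                      pos_scale_base: float = 512.0,
                      recenter_pos: bool = True,
                      downscale: bool = False) -> torch.Tensor:
    """Sketch of XPos-style position powers from a 1-D tensor of positions."""
    pos_pows = pos_ids.to(torch.float32)
    end_ids = pos_ids[-1].to(torch.float32)
    if recenter_pos:
        # Mirror the diff: subtract half of the *last* position id (not the
        # local sequence length), so offset `pos_ids` are not collapsed.
        pos_pows = pos_pows - end_ids / 2
    pos_pows = pos_pows / pos_scale_base
    if downscale:
        pos_pows = -pos_pows
    return pos_pows


# Contiguous positions starting at 0 behave much like the old arange-based path.
print(pos_pows_from_ids(torch.arange(8)))
# Offset positions (e.g. continuing from a cached prefix) keep their absolute values.
print(pos_pows_from_ids(torch.arange(100, 108)))
```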