From 5a22a0f5c03f4b33d1772a9a9dd9d90f69561a2a Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Mon, 17 Apr 2023 17:52:08 +0200
Subject: [PATCH] Fix xPos powers for arbitrary `pos_ids`

---
 bp_transformers/layers/positional_encodings.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/bp_transformers/layers/positional_encodings.py b/bp_transformers/layers/positional_encodings.py
index a59d694..57ac7db 100644
--- a/bp_transformers/layers/positional_encodings.py
+++ b/bp_transformers/layers/positional_encodings.py
@@ -264,10 +264,17 @@ class XPosEncoding(RotaryPositionalEncoding):
         dampening = dampening.reshape((1,) * (x.dim() - 1) + (-1,))
 
         x_shifted = self._prepare_sparse_input(x)
-        seq_len = common.get_seq_len(x, self.batch_first)
-        pos_pows = torch.arange(0, seq_len, device=x.device, dtype=x.dtype)
+        pos_pows = pos_ids.to(x.dtype)
+        if self.batch_first:
+            end_ids = pos_ids[..., -1]
+        else:
+            end_ids = pos_ids[-1]
+        end_ids = end_ids.to(x.dtype)
         if self.recenter_pos:
-            pos_pows = pos_pows - seq_len / 2
+            # We deliberately do not use `start_ids` here to _always_
+            # center around 0; otherwise `pos_ids` with `start_ids` > 0
+            # become meaningless.
+            pos_pows = pos_pows - end_ids / 2
         pos_pows = pos_pows / self.pos_scale_base
         if downscale:
             pos_pows = -pos_pows
-- 
GitLab
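
A minimal, hypothetical sketch (not part of the patch) of what the change buys: recentering with the last position id rather than the sequence length keeps a window of `pos_ids` with a non-zero start consistent with the full sequence. The function name and the value of `pos_scale_base` below are made-up stand-ins for the class attributes and method used in the diff.

import torch

# Stand-in for `self.pos_scale_base`; the real value lives on the encoding class.
pos_scale_base = 512.0

def recentered_pos_pows(pos_ids: torch.Tensor) -> torch.Tensor:
    # Mirror the patched logic: center using the *last* position id rather
    # than the sequence length, so windows with a non-zero start produce the
    # same powers as the corresponding slice of a full 0-based sequence.
    pos_pows = pos_ids.to(torch.float32)
    end_ids = pos_ids[-1].to(torch.float32)
    pos_pows = pos_pows - end_ids / 2
    return pos_pows / pos_scale_base

full = recentered_pos_pows(torch.arange(0, 108))      # positions 0..107
window = recentered_pos_pows(torch.arange(100, 108))  # positions 100..107

# The windowed powers match the tail of the full-sequence powers.
assert torch.allclose(window, full[100:])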