From 5a22a0f5c03f4b33d1772a9a9dd9d90f69561a2a Mon Sep 17 00:00:00 2001
From: janEbert <janpublicebert@posteo.net>
Date: Mon, 17 Apr 2023 17:52:08 +0200
Subject: [PATCH] Fix xPos powers for arbitrary `pos_ids`

---
 bp_transformers/layers/positional_encodings.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/bp_transformers/layers/positional_encodings.py b/bp_transformers/layers/positional_encodings.py
index a59d694..57ac7db 100644
--- a/bp_transformers/layers/positional_encodings.py
+++ b/bp_transformers/layers/positional_encodings.py
@@ -264,10 +264,17 @@ class XPosEncoding(RotaryPositionalEncoding):
         dampening = dampening.reshape((1,) * (x.dim() - 1) + (-1,))
         x_shifted = self._prepare_sparse_input(x)
 
-        seq_len = common.get_seq_len(x, self.batch_first)
-        pos_pows = torch.arange(0, seq_len, device=x.device, dtype=x.dtype)
+        pos_pows = pos_ids.to(x.dtype)
+        if self.batch_first:
+            end_ids = pos_ids[..., -1]
+        else:
+            end_ids = pos_ids[-1]
+        end_ids = end_ids.to(x.dtype)
         if self.recenter_pos:
-            pos_pows = pos_pows - seq_len / 2
+            # We deliberately do not use `start_ids` here to _always_
+            # center around 0; otherwise `pos_ids` with `start_ids` > 0
+            # become meaningless.
+            pos_pows = pos_pows - end_ids / 2
         pos_pows = pos_pows / self.pos_scale_base
         if downscale:
             pos_pows = -pos_pows
-- 
GitLab
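
For readers who want to check the new behaviour outside the class, below is a
minimal standalone sketch of the patched power computation. It assumes a single
1-D `pos_ids` tensor (one sequence, so the `batch_first` branch is irrelevant);
the helper name `xpos_pows` and the `pos_scale_base` default of 512 are
illustrative assumptions, not part of `XPosEncoding`.

    import torch

    def xpos_pows(pos_ids, pos_scale_base=512.0, recenter_pos=True, downscale=False):
        # Hypothetical helper mirroring the patched logic for 1-D `pos_ids`.
        pos_pows = pos_ids.to(torch.float32)
        end_id = pos_pows[-1]  # last position id, used for recentering
        if recenter_pos:
            # Center around 0 as if the sequence started at position 0, so
            # `pos_ids` that start above 0 stay consistent with the full
            # sequence (see the comment in the patch).
            pos_pows = pos_pows - end_id / 2
        pos_pows = pos_pows / pos_scale_base
        if downscale:
            pos_pows = -pos_pows
        return pos_pows

    # Positions 10..13, e.g. the continuation of a longer cached sequence.
    print(xpos_pows(torch.arange(10, 14)))
    # tensor([0.0068, 0.0088, 0.0107, 0.0127])

With the pre-patch `torch.arange(0, seq_len)` logic, the same four-element input
would have been treated as positions 0..3, so the sketch illustrates why the
powers now follow the supplied `pos_ids` instead of the sequence length alone.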