diff --git a/bp_transformers/layers/positional_encodings.py b/bp_transformers/layers/positional_encodings.py
index a59d6941dc0a6f5585d39ab80376873c8eea0c2c..57ac7db826807f6b44c524e9ab589758726c623c 100644
--- a/bp_transformers/layers/positional_encodings.py
+++ b/bp_transformers/layers/positional_encodings.py
@@ -264,10 +264,17 @@ class XPosEncoding(RotaryPositionalEncoding):
         dampening = dampening.reshape((1,) * (x.dim() - 1) + (-1,))
         x_shifted = self._prepare_sparse_input(x)
 
-        seq_len = common.get_seq_len(x, self.batch_first)
-        pos_pows = torch.arange(0, seq_len, device=x.device, dtype=x.dtype)
+        pos_pows = pos_ids.to(x.dtype)
+        if self.batch_first:
+            end_ids = pos_ids[..., -1:]  # keep a trailing dim so it broadcasts against pos_pows
+        else:
+            end_ids = pos_ids[-1]
+        end_ids = end_ids.to(x.dtype)
         if self.recenter_pos:
-            pos_pows = pos_pows - seq_len / 2
+            # We deliberately do not subtract `start_ids` here: that would
+            # _always_ recenter around 0 and make `pos_ids` passed with
+            # `start_ids` > 0 meaningless.
+            pos_pows = pos_pows - end_ids / 2
         pos_pows = pos_pows / self.pos_scale_base
         if downscale:
             pos_pows = -pos_pows