diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py
index d2da9fbcbeb45ebeae4af27cb0ca30ea95f5cf90..a801503a7c53302248bb50a261d1ca6bfc8cca62 100644
--- a/pySDC/helpers/spectral_helper.py
+++ b/pySDC/helpers/spectral_helper.py
@@ -1853,6 +1853,26 @@ class SpectralHelper:
         """
         return M.tocsc()[self.local_slice[axis], self.local_slice[axis]]
 
+    def expand_matrix_ND(self, matrix, aligned):
+        sp = self.sparse_lib
+        axes = np.delete(np.arange(self.ndim), aligned)
+        ndim = len(axes) + 1
+
+        if ndim == 1:
+            return matrix
+        elif ndim == 2:
+            axis = axes[0]
+            I1D = sp.eye(self.axes[axis].N)
+
+            mats = [None] * ndim
+            mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
+            mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
+
+            return sp.kron(*mats)
+
+        else:
+            raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')
+
     def get_filter_matrix(self, axis, **kwargs):
         """
         Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
@@ -1878,31 +1898,10 @@ class SpectralHelper:
         Returns:
             sparse differentiation matrix
         """
-        sp = self.sparse_lib
-        ndim = self.ndim
-
-        if ndim == 1:
-            D = self.axes[0].get_differentiation_matrix(**kwargs)
-        elif ndim == 2:
-            for axis in axes:
-                axis2 = (axis + 1) % ndim
-                D1D = self.axes[axis].get_differentiation_matrix(**kwargs)
-
-                if len(axes) > 1:
-                    I1D = sp.eye(self.axes[axis2].N)
-                else:
-                    I1D = self.axes[axis2].get_Id()
-
-                mats = [None] * ndim
-                mats[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
-                mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
-
-                if axis == axes[0]:
-                    D = sp.kron(*mats)
-                else:
-                    D = D @ sp.kron(*mats)
-        else:
-            raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!')
+        D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0])
+        for axis in axes[1:]:
+            _D = self.axes[axis].get_differentiation_matrix(**kwargs)
+            D = D @ self.expand_matrix_ND(_D, axis)
 
         return D
 
@@ -1916,31 +1915,10 @@ class SpectralHelper:
         Returns:
             sparse integration matrix
         """
-        sp = self.sparse_lib
-        ndim = len(self.axes)
-
-        if ndim == 1:
-            S = self.axes[0].get_integration_matrix()
-        elif ndim == 2:
-            for axis in axes:
-                axis2 = (axis + 1) % ndim
-                S1D = self.axes[axis].get_integration_matrix()
-
-                if len(axes) > 1:
-                    I1D = sp.eye(self.axes[axis2].N)
-                else:
-                    I1D = self.axes[axis2].get_Id()
-
-                mats = [None] * ndim
-                mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
-                mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
-
-                if axis == axes[0]:
-                    S = sp.kron(*mats)
-                else:
-                    S = S @ sp.kron(*mats)
-        else:
-            raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')
+        S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0])
+        for axis in axes[1:]:
+            _S = self.axes[axis].get_integration_matrix()
+            S = S @ self.expand_matrix_ND(_S, axis)
 
         return S
 
@@ -1951,27 +1929,10 @@ class SpectralHelper:
         Returns:
             sparse identity matrix
         """
-        sp = self.sparse_lib
-        ndim = self.ndim
-        I = sp.eye(np.prod(self.init[0][1:]), dtype=complex)
-
-        if ndim == 1:
-            I = self.axes[0].get_Id()
-        elif ndim == 2:
-            for axis in range(ndim):
-                axis2 = (axis + 1) % ndim
-                I1D = self.axes[axis].get_Id()
-
-                I1D2 = sp.eye(self.axes[axis2].N)
-
-                mats = [None] * ndim
-                mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
-                mats[axis2] = self.get_local_slice_of_1D_matrix(I1D2, axis2)
-
-                I = I @ sp.kron(*mats)
-        else:
-            raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!')
-
+        I = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
+        for axis in range(1, self.ndim):
+            _I = self.axes[axis].get_Id()
+            I = I @ self.expand_matrix_ND(_I, axis)
         return I
 
     def get_Dirichlet_recombination_matrix(self, axis=-1):
@@ -1984,26 +1945,8 @@ class SpectralHelper:
         Returns:
             sparse matrix
         """
-        sp = self.sparse_lib
-        ndim = len(self.axes)
-
-        if ndim == 1:
-            C = self.axes[0].get_Dirichlet_recombination_matrix()
-        elif ndim == 2:
-            axis2 = (axis + 1) % ndim
-            C1D = self.axes[axis].get_Dirichlet_recombination_matrix()
-
-            I1D = self.axes[axis2].get_Id()
-
-            mats = [None] * ndim
-            mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
-            mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
-
-            C = sp.kron(*mats)
-        else:
-            raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')
-
-        return C
+        C1D = self.axes[axis].get_Dirichlet_recombination_matrix()
+        return self.expand_matrix_ND(C1D, axis)
 
     def get_basis_change_matrix(self, axes=None, **kwargs):
         """
@@ -2018,30 +1961,9 @@ class SpectralHelper:
         """
         axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
 
-        sp = self.sparse_lib
-        ndim = len(self.axes)
-
-        if ndim == 1:
-            C = self.axes[0].get_basis_change_matrix(**kwargs)
-        elif ndim == 2:
-            for axis in axes:
-                axis2 = (axis + 1) % ndim
-                C1D = self.axes[axis].get_basis_change_matrix(**kwargs)
-
-                if len(axes) > 1:
-                    I1D = sp.eye(self.axes[axis2].N)
-                else:
-                    I1D = self.axes[axis2].get_Id()
-
-                mats = [None] * ndim
-                mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
-                mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
-
-                if axis == axes[0]:
-                    C = sp.kron(*mats)
-                else:
-                    C = C @ sp.kron(*mats)
-        else:
-            raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')
+        C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0])
+        for axis in axes[1:]:
+            _C = self.axes[axis].get_basis_change_matrix(**kwargs)
+            C = C @ self.expand_matrix_ND(_C, axis)
 
         return C