From dc3eed1b0fcdf1fa372815059170a8b588477e23 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 2 Apr 2025 13:15:02 +0200 Subject: [PATCH] Small refactor of spectral helper to prepare for 3D (#541) --- pySDC/helpers/spectral_helper.py | 154 ++++++++----------------------- 1 file changed, 38 insertions(+), 116 deletions(-) diff --git a/pySDC/helpers/spectral_helper.py b/pySDC/helpers/spectral_helper.py index d2da9fbc..a801503a 100644 --- a/pySDC/helpers/spectral_helper.py +++ b/pySDC/helpers/spectral_helper.py @@ -1853,6 +1853,26 @@ class SpectralHelper: """ return M.tocsc()[self.local_slice[axis], self.local_slice[axis]] + def expand_matrix_ND(self, matrix, aligned): + sp = self.sparse_lib + axes = np.delete(np.arange(self.ndim), aligned) + ndim = len(axes) + 1 + + if ndim == 1: + return matrix + elif ndim == 2: + axis = axes[0] + I1D = sp.eye(self.axes[axis].N) + + mats = [None] * ndim + mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned) + mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) + + return sp.kron(*mats) + + else: + raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!') + def get_filter_matrix(self, axis, **kwargs): """ Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are @@ -1878,31 +1898,10 @@ class SpectralHelper: Returns: sparse differentiation matrix """ - sp = self.sparse_lib - ndim = self.ndim - - if ndim == 1: - D = self.axes[0].get_differentiation_matrix(**kwargs) - elif ndim == 2: - for axis in axes: - axis2 = (axis + 1) % ndim - D1D = self.axes[axis].get_differentiation_matrix(**kwargs) - - if len(axes) > 1: - I1D = sp.eye(self.axes[axis2].N) - else: - I1D = self.axes[axis2].get_Id() - - mats = [None] * ndim - mats[axis] = self.get_local_slice_of_1D_matrix(D1D, axis) - mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) - - if axis == axes[0]: - D = sp.kron(*mats) - else: - D = D @ sp.kron(*mats) - else: - raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!') + D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0]) + for axis in axes[1:]: + _D = self.axes[axis].get_differentiation_matrix(**kwargs) + D = D @ self.expand_matrix_ND(_D, axis) return D @@ -1916,31 +1915,10 @@ class SpectralHelper: Returns: sparse integration matrix """ - sp = self.sparse_lib - ndim = len(self.axes) - - if ndim == 1: - S = self.axes[0].get_integration_matrix() - elif ndim == 2: - for axis in axes: - axis2 = (axis + 1) % ndim - S1D = self.axes[axis].get_integration_matrix() - - if len(axes) > 1: - I1D = sp.eye(self.axes[axis2].N) - else: - I1D = self.axes[axis2].get_Id() - - mats = [None] * ndim - mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis) - mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) - - if axis == axes[0]: - S = sp.kron(*mats) - else: - S = S @ sp.kron(*mats) - else: - raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!') + S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0]) + for axis in axes[1:]: + _S = self.axes[axis].get_integration_matrix() + S = S @ self.expand_matrix_ND(_S, axis) return S @@ -1951,27 +1929,10 @@ class SpectralHelper: Returns: sparse identity matrix """ - sp = self.sparse_lib - ndim = self.ndim - I = sp.eye(np.prod(self.init[0][1:]), dtype=complex) - - if ndim == 1: - I = self.axes[0].get_Id() - elif ndim == 2: - for axis in range(ndim): - axis2 = (axis + 1) % ndim - I1D = self.axes[axis].get_Id() - - I1D2 = sp.eye(self.axes[axis2].N) - - mats = [None] * ndim - mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis) - mats[axis2] = self.get_local_slice_of_1D_matrix(I1D2, axis2) - - I = I @ sp.kron(*mats) - else: - raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!') - + I = self.expand_matrix_ND(self.axes[0].get_Id(), 0) + for axis in range(1, self.ndim): + _I = self.axes[axis].get_Id() + I = I @ self.expand_matrix_ND(_I, axis) return I def get_Dirichlet_recombination_matrix(self, axis=-1): @@ -1984,26 +1945,8 @@ class SpectralHelper: Returns: sparse matrix """ - sp = self.sparse_lib - ndim = len(self.axes) - - if ndim == 1: - C = self.axes[0].get_Dirichlet_recombination_matrix() - elif ndim == 2: - axis2 = (axis + 1) % ndim - C1D = self.axes[axis].get_Dirichlet_recombination_matrix() - - I1D = self.axes[axis2].get_Id() - - mats = [None] * ndim - mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis) - mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) - - C = sp.kron(*mats) - else: - raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!') - - return C + C1D = self.axes[axis].get_Dirichlet_recombination_matrix() + return self.expand_matrix_ND(C1D, axis) def get_basis_change_matrix(self, axes=None, **kwargs): """ @@ -2018,30 +1961,9 @@ class SpectralHelper: """ axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes - sp = self.sparse_lib - ndim = len(self.axes) - - if ndim == 1: - C = self.axes[0].get_basis_change_matrix(**kwargs) - elif ndim == 2: - for axis in axes: - axis2 = (axis + 1) % ndim - C1D = self.axes[axis].get_basis_change_matrix(**kwargs) - - if len(axes) > 1: - I1D = sp.eye(self.axes[axis2].N) - else: - I1D = self.axes[axis2].get_Id() - - mats = [None] * ndim - mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis) - mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2) - - if axis == axes[0]: - C = sp.kron(*mats) - else: - C = C @ sp.kron(*mats) - else: - raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!') + C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0]) + for axis in axes[1:]: + _C = self.axes[axis].get_basis_change_matrix(**kwargs) + C = C @ self.expand_matrix_ND(_C, axis) return C -- GitLab