Skip to content
Snippets Groups Projects
Unverified Commit dc3eed1b authored by Thomas Baumann's avatar Thomas Baumann Committed by GitHub
Browse files

Small refactor of spectral helper to prepare for 3D (#541)

parent c3598e54
Branches
Tags v5.5.3
No related merge requests found
Pipeline #266120 passed
......@@ -1853,6 +1853,26 @@ class SpectralHelper:
"""
return M.tocsc()[self.local_slice[axis], self.local_slice[axis]]
def expand_matrix_ND(self, matrix, aligned):
sp = self.sparse_lib
axes = np.delete(np.arange(self.ndim), aligned)
ndim = len(axes) + 1
if ndim == 1:
return matrix
elif ndim == 2:
axis = axes[0]
I1D = sp.eye(self.axes[axis].N)
mats = [None] * ndim
mats[aligned] = self.get_local_slice_of_1D_matrix(matrix, aligned)
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
return sp.kron(*mats)
else:
raise NotImplementedError(f'Matrix expansion not implemented for {ndim} dimensions!')
def get_filter_matrix(self, axis, **kwargs):
"""
Get bandpass filter along `axis`. See the documentation `get_filter_matrix` in the 1D bases for what kwargs are
......@@ -1878,31 +1898,10 @@ class SpectralHelper:
Returns:
sparse differentiation matrix
"""
sp = self.sparse_lib
ndim = self.ndim
if ndim == 1:
D = self.axes[0].get_differentiation_matrix(**kwargs)
elif ndim == 2:
for axis in axes:
axis2 = (axis + 1) % ndim
D1D = self.axes[axis].get_differentiation_matrix(**kwargs)
if len(axes) > 1:
I1D = sp.eye(self.axes[axis2].N)
else:
I1D = self.axes[axis2].get_Id()
mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(D1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
if axis == axes[0]:
D = sp.kron(*mats)
else:
D = D @ sp.kron(*mats)
else:
raise NotImplementedError(f'Differentiation matrix not implemented for {ndim} dimension!')
D = self.expand_matrix_ND(self.axes[axes[0]].get_differentiation_matrix(**kwargs), axes[0])
for axis in axes[1:]:
_D = self.axes[axis].get_differentiation_matrix(**kwargs)
D = D @ self.expand_matrix_ND(_D, axis)
return D
......@@ -1916,31 +1915,10 @@ class SpectralHelper:
Returns:
sparse integration matrix
"""
sp = self.sparse_lib
ndim = len(self.axes)
if ndim == 1:
S = self.axes[0].get_integration_matrix()
elif ndim == 2:
for axis in axes:
axis2 = (axis + 1) % ndim
S1D = self.axes[axis].get_integration_matrix()
if len(axes) > 1:
I1D = sp.eye(self.axes[axis2].N)
else:
I1D = self.axes[axis2].get_Id()
mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(S1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
if axis == axes[0]:
S = sp.kron(*mats)
else:
S = S @ sp.kron(*mats)
else:
raise NotImplementedError(f'Integration matrix not implemented for {ndim} dimension!')
S = self.expand_matrix_ND(self.axes[axes[0]].get_integration_matrix(), axes[0])
for axis in axes[1:]:
_S = self.axes[axis].get_integration_matrix()
S = S @ self.expand_matrix_ND(_S, axis)
return S
......@@ -1951,27 +1929,10 @@ class SpectralHelper:
Returns:
sparse identity matrix
"""
sp = self.sparse_lib
ndim = self.ndim
I = sp.eye(np.prod(self.init[0][1:]), dtype=complex)
if ndim == 1:
I = self.axes[0].get_Id()
elif ndim == 2:
for axis in range(ndim):
axis2 = (axis + 1) % ndim
I1D = self.axes[axis].get_Id()
I1D2 = sp.eye(self.axes[axis2].N)
mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(I1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D2, axis2)
I = I @ sp.kron(*mats)
else:
raise NotImplementedError(f'Identity matrix not implemented for {ndim} dimension!')
I = self.expand_matrix_ND(self.axes[0].get_Id(), 0)
for axis in range(1, self.ndim):
_I = self.axes[axis].get_Id()
I = I @ self.expand_matrix_ND(_I, axis)
return I
def get_Dirichlet_recombination_matrix(self, axis=-1):
......@@ -1984,26 +1945,8 @@ class SpectralHelper:
Returns:
sparse matrix
"""
sp = self.sparse_lib
ndim = len(self.axes)
if ndim == 1:
C = self.axes[0].get_Dirichlet_recombination_matrix()
elif ndim == 2:
axis2 = (axis + 1) % ndim
C1D = self.axes[axis].get_Dirichlet_recombination_matrix()
I1D = self.axes[axis2].get_Id()
mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
C = sp.kron(*mats)
else:
raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')
return C
return self.expand_matrix_ND(C1D, axis)
def get_basis_change_matrix(self, axes=None, **kwargs):
"""
......@@ -2018,30 +1961,9 @@ class SpectralHelper:
"""
axes = tuple(-i - 1 for i in range(self.ndim)) if axes is None else axes
sp = self.sparse_lib
ndim = len(self.axes)
if ndim == 1:
C = self.axes[0].get_basis_change_matrix(**kwargs)
elif ndim == 2:
for axis in axes:
axis2 = (axis + 1) % ndim
C1D = self.axes[axis].get_basis_change_matrix(**kwargs)
if len(axes) > 1:
I1D = sp.eye(self.axes[axis2].N)
else:
I1D = self.axes[axis2].get_Id()
mats = [None] * ndim
mats[axis] = self.get_local_slice_of_1D_matrix(C1D, axis)
mats[axis2] = self.get_local_slice_of_1D_matrix(I1D, axis2)
if axis == axes[0]:
C = sp.kron(*mats)
else:
C = C @ sp.kron(*mats)
else:
raise NotImplementedError(f'Basis change matrix not implemented for {ndim} dimension!')
C = self.expand_matrix_ND(self.axes[axes[0]].get_basis_change_matrix(**kwargs), axes[0])
for axis in axes[1:]:
_C = self.axes[axis].get_basis_change_matrix(**kwargs)
C = C @ self.expand_matrix_ND(_C, axis)
return C
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment