Skip to content
Snippets Groups Projects
Unverified Commit e3510910 authored by Thomas Baumann's avatar Thomas Baumann Committed by GitHub
Browse files

Added caching decorator (#554)

* Implemented caching wrapper for spectral helper

* Added test for caching decorator
parent c5fbd092
No related branches found
No related tags found
No related merge requests found
Pipeline #283512 passed
...@@ -2,6 +2,54 @@ import numpy as np ...@@ -2,6 +2,54 @@ import numpy as np
import scipy import scipy
from pySDC.implementations.datatype_classes.mesh import mesh from pySDC.implementations.datatype_classes.mesh import mesh
from scipy.special import factorial from scipy.special import factorial
from functools import wraps
def cache(func):
"""
Decorator for caching return values of functions.
This is very similar to `functools.cache`, but without the memory leaks (see
https://docs.astral.sh/ruff/rules/cached-instance-method/).
Example:
.. code-block:: python
num_calls = 0
@cache
def increment(x):
num_calls += 1
return x + 1
increment(0) # returns 1, num_calls = 1
increment(1) # returns 2, num_calls = 2
increment(0) # returns 1, num_calls = 2
Args:
func (function): The function you want to cache the return value of
Returns:
return value of func
"""
attr_cache = f"_{func.__name__}_cache"
@wraps(func)
def wrapper(self, *args, **kwargs):
if not hasattr(self, attr_cache):
setattr(self, attr_cache, {})
cache = getattr(self, attr_cache)
key = (args, frozenset(kwargs.items()))
if key in cache:
return cache[key]
result = func(self, *args, **kwargs)
cache[key] = result
return result
return wrapper
class SpectralHelper1D: class SpectralHelper1D:
...@@ -203,7 +251,6 @@ class ChebychevHelper(SpectralHelper1D): ...@@ -203,7 +251,6 @@ class ChebychevHelper(SpectralHelper1D):
if self.transform_type == 'fft': if self.transform_type == 'fft':
self.get_fft_utils() self.get_fft_utils()
self.cache = {}
self.norm = self.get_norm() self.norm = self.get_norm()
def get_1dgrid(self): def get_1dgrid(self):
...@@ -221,6 +268,7 @@ class ChebychevHelper(SpectralHelper1D): ...@@ -221,6 +268,7 @@ class ChebychevHelper(SpectralHelper1D):
"""Get the domain in spectral space""" """Get the domain in spectral space"""
return self.xp.arange(self.N) return self.xp.arange(self.N)
@cache
def get_conv(self, name, N=None): def get_conv(self, name, N=None):
''' '''
Get conversion matrix between different kinds of polynomials. The supported kinds are Get conversion matrix between different kinds of polynomials. The supported kinds are
...@@ -238,9 +286,6 @@ class ChebychevHelper(SpectralHelper1D): ...@@ -238,9 +286,6 @@ class ChebychevHelper(SpectralHelper1D):
Returns: Returns:
scipy.sparse: Sparse conversion matrix scipy.sparse: Sparse conversion matrix
''' '''
if name in self.cache.keys() and not N:
return self.cache[name]
N = N if N else self.N N = N if N else self.N
sp = self.sparse_lib sp = self.sparse_lib
xp = self.xp xp = self.xp
...@@ -271,7 +316,6 @@ class ChebychevHelper(SpectralHelper1D): ...@@ -271,7 +316,6 @@ class ChebychevHelper(SpectralHelper1D):
except NotImplementedError: except NotImplementedError:
raise NotImplementedError from E raise NotImplementedError from E
self.cache[name] = mat
return mat return mat
def get_basis_change_matrix(self, conv='T2T', **kwargs): def get_basis_change_matrix(self, conv='T2T', **kwargs):
......
...@@ -551,6 +551,30 @@ def test_dealias_MPI(num_procs, axis, bx, bz, nx=32, nz=64, **kwargs): ...@@ -551,6 +551,30 @@ def test_dealias_MPI(num_procs, axis, bx, bz, nx=32, nz=64, **kwargs):
run_MPI_test(num_procs=num_procs, axis=axis, nx=nx, nz=nz, bx=bx, bz=bz, test='dealias') run_MPI_test(num_procs=num_procs, axis=axis, nx=nx, nz=nz, bx=bx, bz=bz, test='dealias')
@pytest.mark.base
def test_cache_decorator():
from pySDC.helpers.spectral_helper import cache
import numpy as np
class Dummy:
num_calls = 0
@cache
def increment(self, x):
self.num_calls += 1
return x + 1
dummy = Dummy()
values = [0, 1, 1, 0, 3, 1, 2]
unique_vals = np.unique(values)
for x in values:
assert dummy.increment(x) == x + 1
assert dummy.num_calls < len(values)
assert dummy.num_calls == len(unique_vals)
if __name__ == '__main__': if __name__ == '__main__':
str_to_bool = lambda me: False if me == 'False' else True str_to_bool = lambda me: False if me == 'False' else True
str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(',')) str_to_tuple = lambda arg: tuple(int(me) for me in arg.split(','))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment