Commit 2825857e authored by Sync bot

Merge commit '7a63ca2c' into TEMPORARY_MERGE_PR_469

parents 5b3a12e3 7a63ca2c
@@ -71,7 +71,7 @@ class EstimatePolynomialError(ConvergenceController):
self.add_status_variable_to_level('error_embedded_estimate')
self.add_status_variable_to_level('order_embedded_estimate')
- def matmul(self, A, b):
+ def matmul(self, A, b, xp=np):
"""
Matrix vector multiplication, possibly MPI parallel.
The parallel implementation performs a reduce operation in every row of the matrix. While communicating the
@@ -81,10 +81,12 @@ class EstimatePolynomialError(ConvergenceController):
Args:
A (2d np.ndarray): Matrix
b (list): Vector
+ xp: Either numpy or cupy
Returns:
List: Axb
"""
if self.comm:
res = [A[i, 0] * b[0] if b[i] is not None else None for i in range(A.shape[0])]
buf = b[0] * 0.0
@@ -93,13 +95,13 @@ class EstimatePolynomialError(ConvergenceController):
send_buf = (
(A[i, index] * b[index])
if self.comm.rank != self.params.estimate_on_node - 1
- else np.zeros_like(res[0])
+ else xp.zeros_like(res[0])
)
self.comm.Allreduce(send_buf, buf, op=self.MPI_SUM)
res[i] += buf
return res
else:
- return A @ np.asarray(b)
+ return A @ xp.asarray(b)
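Aside: the `xp=np` default is what keeps `matmul` backend-agnostic, since NumPy and CuPy expose the same interface for `asarray` and `zeros_like`. A minimal sketch of the dispatch pattern in isolation (the function name here is illustrative, not pySDC's API):

```python
import numpy as np

def matmul_serial(A, b, xp=np):
    """Serial matrix-vector product; xp is the array module (numpy or cupy)."""
    return A @ xp.asarray(b)

A = np.array([[1.0, 2.0], [3.0, 4.0]])
print(matmul_serial(A, [1.0, 1.0]))  # -> [3. 7.]

# With CuPy installed, the identical call runs on the GPU:
# import cupy as cp
# print(matmul_serial(cp.asarray(A), [1.0, 1.0], xp=cp))
```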
def post_iteration_processing(self, controller, S, **kwargs):
"""
@@ -118,19 +120,20 @@
coll = L.sweep.coll
nodes = np.append(np.append(0, coll.nodes), 1.0)
estimate_on_node = self.params.estimate_on_node
+ xp = L.u[0].xp
if self.interpolation_matrix is None:
interpolator = LagrangeApproximation(
points=[nodes[i] for i in range(coll.num_nodes + 1) if i != estimate_on_node]
)
- self.interpolation_matrix = interpolator.getInterpolationMatrix([nodes[estimate_on_node]])
+ self.interpolation_matrix = xp.array(interpolator.getInterpolationMatrix([nodes[estimate_on_node]]))
u = [
L.u[i].flatten() if L.u[i] is not None else L.u[i]
for i in range(coll.num_nodes + 1)
if i != estimate_on_node
]
- u_inter = self.matmul(self.interpolation_matrix, u)[0].reshape(L.prob.init[0])
+ u_inter = self.matmul(self.interpolation_matrix, u, xp=xp)[0].reshape(L.prob.init[0])
# compute end point if needed
if estimate_on_node == len(nodes) - 1:
......
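For orientation: the estimator interpolates the collocation solution from all nodes except `estimate_on_node` to that node; the mismatch with the value actually computed there serves as the embedded error estimate, and its order is what the tests further down check. A self-contained NumPy sketch of the interpolation-matrix idea (pySDC builds this via its `LagrangeApproximation` helper; the explicit formula here is only for illustration):

```python
import numpy as np

def lagrange_row(points, x):
    """Weights w such that w @ f(points) equals the polynomial
    through (points, f(points)) evaluated at x."""
    points = np.asarray(points, dtype=float)
    w = np.ones_like(points)
    for j in range(len(points)):
        for k in range(len(points)):
            if k != j:
                w[j] *= (x - points[k]) / (points[j] - points[k])
    return w

nodes = np.array([0.0, 0.2, 0.5, 0.8, 1.0])
estimate_on_node = 3                       # leave out the node at 0.8
rest = np.delete(nodes, estimate_on_node)
row = lagrange_row(rest, nodes[estimate_on_node])

u = np.sin(nodes)                          # stand-in for the collocation solution
u_inter = row @ np.sin(rest)               # interpolated value at the left-out node
print(abs(u_inter - u[estimate_on_node]))  # small residual -> error estimate
```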
@@ -12,7 +12,10 @@ class cupy_mesh(cp.ndarray):
CuPy-based datatype for serial or parallel meshes.
"""
- def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None):
+ comm = None
+ xp = cp
+ def __new__(cls, init, val=0.0, **kwargs):
"""
Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.
@@ -25,51 +28,30 @@ class cupy_mesh(cp.ndarray):
"""
if isinstance(init, cupy_mesh):
- obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, strides=strides, order=order)
+ obj = cp.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
obj[:] = init[:]
- obj._comm = init._comm
elif (
isinstance(init, tuple)
and (init[1] is None or isinstance(init[1], MPI.Intracomm))
and isinstance(init[2], cp.dtype)
):
- obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], strides=strides, order=order)
+ obj = cp.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
obj.fill(val)
- obj._comm = init[1]
else:
raise NotImplementedError(type(init))
return obj
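The `**kwargs` signature forwards the remaining `ndarray` constructor arguments (`strides`, `order`, ...) instead of spelling them out, and `comm`/`xp` move to class attributes. The same subclassing pattern, sketched with NumPy so it runs without a GPU (a simplified stand-in, not the real `cupy_mesh`):

```python
import numpy as np

class toy_mesh(np.ndarray):
    comm = None  # class-level communicator, mirroring the new cupy_mesh
    xp = np

    def __new__(cls, init, val=0.0, **kwargs):
        if isinstance(init, toy_mesh):
            obj = np.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
            obj[:] = init[:]
        elif isinstance(init, tuple):
            # init = (shape, communicator, dtype), as in pySDC's mesh types
            obj = np.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
            obj.fill(val)
        else:
            raise NotImplementedError(type(init))
        return obj

m = toy_mesh((4, None, np.dtype('float64')), val=1.0)
print(type(m))      # construction preserves the subclass
print(type(m + m))  # arithmetic preserves it too
```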
- @property
- def comm(self):
- """
- Getter for the communicator
- """
- return self._comm
- def __array_finalize__(self, obj):
- """
- Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
- """
- if obj is None:
- return
- self._comm = getattr(obj, '_comm', None)
- def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
- """
- Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs
- """
- args = []
- comm = None
- for _, input_ in enumerate(inputs):
- if isinstance(input_, cupy_mesh):
- args.append(input_.view(cp.ndarray))
- comm = input_.comm
- else:
- args.append(input_)
- results = super(cupy_mesh, self).__array_ufunc__(ufunc, method, *args, **kwargs).view(cupy_mesh)
- if not method == 'reduce':
- results._comm = comm
- return results
def __abs__(self):
......
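The deleted `__array_ufunc__` existed only to carry the instance-level communicator through arithmetic; with `comm` promoted to a class attribute, every result sees it automatically and the override becomes redundant. A minimal NumPy illustration of the mechanism being dropped (toy names, single-output ufuncs only, following the numpy subclassing guide):

```python
import numpy as np

class tagged(np.ndarray):
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        tag = None
        args = []
        for input_ in inputs:
            if isinstance(input_, tagged):
                args.append(input_.view(np.ndarray))  # unwrap to avoid recursion
                tag = getattr(input_, 'tag', None)
            else:
                args.append(input_)
        result = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        if method != 'reduce':
            result = result.view(tagged)
            result.tag = tag  # re-attach the instance attribute by hand
        return result

a = np.zeros(3).view(tagged)
a.tag = 'stand-in for a communicator'
print((a + 1).tag)  # the tag survives the ufunc only because of the override
```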
@@ -19,6 +19,7 @@ class mesh(np.ndarray):
"""
comm = None
+ xp = np
def __new__(cls, init, val=0.0, **kwargs):
"""
......
@@ -12,10 +12,19 @@ class polynomial_testequation(Problem):
dtype_u = mesh
dtype_f = mesh
+ xp = np
- def __init__(self, degree=1, seed=26266):
+ def __init__(self, degree=1, seed=26266, useGPU=False):
"""Initialization routine"""
+ if useGPU:
+ import cupy as cp
+ from pySDC.implementations.datatype_classes.cupy_mesh import cupy_mesh
+ type(self).xp = cp
+ type(self).dtype_u = cupy_mesh
+ type(self).dtype_f = cupy_mesh
# invoke super init, passing number of dofs, dtype_u and dtype_f
super().__init__(init=(1, None, np.dtype('float64')))
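Note that `type(self).xp = cp` rebinds the attribute on the class itself, so the switch affects every instance created afterwards and a single `xp` lookup works throughout the methods below. A stripped-down sketch of the pattern (hypothetical class, guarded so it runs without CuPy):

```python
import numpy as np

class BackendSwitchingProblem:
    xp = np  # default backend: NumPy on the CPU

    def __init__(self, useGPU=False):
        if useGPU:
            import cupy as cp   # imported only when a GPU is requested
            type(self).xp = cp  # rebind on the class, as the diff does
        self.data = self.xp.zeros(8)

prob = BackendSwitchingProblem()
print(prob.xp.__name__)  # 'numpy'; with useGPU=True and CuPy installed: 'cupy'
```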
@@ -41,7 +50,7 @@ class polynomial_testequation(Problem):
"""
f = self.dtype_f(self.init)
- f[:] = self.poly.deriv(m=1)(t)
+ f[:] = self.xp.array(self.poly.deriv(m=1)(t))
return f
def solve_system(self, rhs, factor, u0, t):
@@ -86,7 +95,7 @@ class polynomial_testequation(Problem):
The exact solution.
"""
me = self.dtype_u(self.init)
- me[:] = self.poly(t)
+ me[:] = self.xp.array(self.poly(t))
return me
@@ -116,7 +125,7 @@ class polynomial_testequation_IMEX(polynomial_testequation):
"""
f = self.dtype_f(self.init)
- derivative = self.poly.deriv(m=1)(t)
+ derivative = self.xp.array(self.poly.deriv(m=1)(t))
f.impl[:] = derivative / 2
f.expl[:] = derivative / 2
return f
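The recurring `self.xp.array(...)` wrapper is needed because `self.poly` is a NumPy polynomial: its evaluation always happens on the host, so the result has to be converted to the active backend before it is written into a (possibly GPU-resident) mesh. A small sketch of the pattern (NumPy standing in for the host side; swap `xp` for `cupy` on a GPU machine):

```python
import numpy as np
# import cupy as cp  # on a GPU machine

xp = np  # or cp
poly = np.polynomial.Polynomial([1.0, 2.0, 3.0])

f = xp.zeros(1)
# poly(t) is computed by NumPy on the host either way;
# xp.array(...) moves the result onto the active backend first
f[:] = xp.array(poly(0.5))
print(f)  # [2.75]
```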
import pytest
- def get_controller(dt, num_nodes, quad_type, useMPI):
+ def get_controller(dt, num_nodes, quad_type, useMPI, useGPU):
"""
Get a controller prepared for polynomial test equation
@@ -42,7 +42,10 @@ def get_controller(dt, num_nodes, quad_type, useMPI):
sweeper_params['num_nodes'] = num_nodes
sweeper_params['comm'] = comm
- problem_params = {'degree': 12}
+ problem_params = {
+ 'degree': 12,
+ 'useGPU': useGPU,
+ }
# initialize step parameters
step_params = {}
@@ -82,6 +85,7 @@ def single_test(**kwargs):
'quad_type': 'RADAU-RIGHT',
'useMPI': False,
'dt': 1.0,
+ 'useGPU': False,
**kwargs,
}
@@ -131,8 +135,6 @@ def multiple_runs(dts, **kwargs):
dict: Errors for multiple runs
int: Order of the collocation problem
"""
- from pySDC.helpers.stats_helper import get_sorted
res = {}
for dt in dts:
@@ -187,6 +189,22 @@ def test_interpolation_error(num_nodes, quad_type):
check_order(steps, **kwargs)
+ @pytest.mark.cupy
+ @pytest.mark.parametrize('num_nodes', [2, 3, 4, 5])
+ @pytest.mark.parametrize('quad_type', ['RADAU-RIGHT', 'GAUSS'])
+ def test_interpolation_error_GPU(num_nodes, quad_type):
+ import numpy as np
+ kwargs = {
+ 'num_nodes': num_nodes,
+ 'quad_type': quad_type,
+ 'useMPI': False,
+ 'useGPU': True,
+ }
+ steps = np.logspace(-1, -4, 20)
+ check_order(steps, **kwargs)
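Since the GPU variant is gated behind the `cupy` marker, it only runs where a CuPy-capable runner picks it up; assuming the marker is registered in the project's pytest configuration, `pytest -m cupy` selects exactly these tests while leaving the CPU tests above untouched.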
@pytest.mark.mpi4py
@pytest.mark.parametrize('num_nodes', [2, 5])
@pytest.mark.parametrize('quad_type', ['RADAU-RIGHT', 'GAUSS'])
......