Skip to content
Snippets Groups Projects
Commit 7a63ca2c authored by Thomas Baumann's avatar Thomas Baumann
Browse files

Fixed GPU compatibility of polynomial test problem

parent 46b6d2b3
Branches
No related tags found
No related merge requests found
......@@ -28,8 +28,8 @@ class polynomial_testequation(Problem):
# invoke super init, passing number of dofs, dtype_u and dtype_f
super().__init__(init=(1, None, np.dtype('float64')))
self.rng = self.xp.random.RandomState(seed=seed)
self.poly = self.xp.polynomial.Polynomial(self.rng.rand(degree))
self.rng = np.random.RandomState(seed=seed)
self.poly = np.polynomial.Polynomial(self.rng.rand(degree))
self._makeAttributeAndRegister('degree', 'seed', localVars=locals(), readOnly=True)
def eval_f(self, u, t):
......@@ -50,7 +50,7 @@ class polynomial_testequation(Problem):
"""
f = self.dtype_f(self.init)
f[:] = self.poly.deriv(m=1)(t)
f[:] = self.xp.array(self.poly.deriv(m=1)(t))
return f
def solve_system(self, rhs, factor, u0, t):
......@@ -95,7 +95,7 @@ class polynomial_testequation(Problem):
The exact solution.
"""
me = self.dtype_u(self.init)
me[:] = self.poly(t)
me[:] = self.xp.array(self.poly(t))
return me
......@@ -125,7 +125,7 @@ class polynomial_testequation_IMEX(polynomial_testequation):
"""
f = self.dtype_f(self.init)
derivative = self.poly.deriv(m=1)(t)
derivative = self.xp.array(self.poly.deriv(m=1)(t))
f.impl[:] = derivative / 2
f.expl[:] = derivative / 2
return f
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment