Skip to content
Snippets Groups Projects
Commit 7a63ca2c authored by Thomas Baumann's avatar Thomas Baumann
Browse files

Fixed GPU compatibility of polynomial test problem

parent 46b6d2b3
No related branches found
No related tags found
No related merge requests found
...@@ -28,8 +28,8 @@ class polynomial_testequation(Problem): ...@@ -28,8 +28,8 @@ class polynomial_testequation(Problem):
# invoke super init, passing number of dofs, dtype_u and dtype_f # invoke super init, passing number of dofs, dtype_u and dtype_f
super().__init__(init=(1, None, np.dtype('float64'))) super().__init__(init=(1, None, np.dtype('float64')))
self.rng = self.xp.random.RandomState(seed=seed) self.rng = np.random.RandomState(seed=seed)
self.poly = self.xp.polynomial.Polynomial(self.rng.rand(degree)) self.poly = np.polynomial.Polynomial(self.rng.rand(degree))
self._makeAttributeAndRegister('degree', 'seed', localVars=locals(), readOnly=True) self._makeAttributeAndRegister('degree', 'seed', localVars=locals(), readOnly=True)
def eval_f(self, u, t): def eval_f(self, u, t):
...@@ -50,7 +50,7 @@ class polynomial_testequation(Problem): ...@@ -50,7 +50,7 @@ class polynomial_testequation(Problem):
""" """
f = self.dtype_f(self.init) f = self.dtype_f(self.init)
f[:] = self.poly.deriv(m=1)(t) f[:] = self.xp.array(self.poly.deriv(m=1)(t))
return f return f
def solve_system(self, rhs, factor, u0, t): def solve_system(self, rhs, factor, u0, t):
...@@ -95,7 +95,7 @@ class polynomial_testequation(Problem): ...@@ -95,7 +95,7 @@ class polynomial_testequation(Problem):
The exact solution. The exact solution.
""" """
me = self.dtype_u(self.init) me = self.dtype_u(self.init)
me[:] = self.poly(t) me[:] = self.xp.array(self.poly(t))
return me return me
...@@ -125,7 +125,7 @@ class polynomial_testequation_IMEX(polynomial_testequation): ...@@ -125,7 +125,7 @@ class polynomial_testequation_IMEX(polynomial_testequation):
""" """
f = self.dtype_f(self.init) f = self.dtype_f(self.init)
derivative = self.poly.deriv(m=1)(t) derivative = self.xp.array(self.poly.deriv(m=1)(t))
f.impl[:] = derivative / 2 f.impl[:] = derivative / 2
f.expl[:] = derivative / 2 f.expl[:] = derivative / 2
return f return f
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment