From 7a63ca2c727723ef188f67b233af11ee04ecb652 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <t.baumann@fz-juelich.de>
Date: Mon, 26 Aug 2024 12:25:35 +0200
Subject: [PATCH] Fixed GPU compatibility of polynomial test problem

---
 .../problem_classes/polynomial_test_problem.py         | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pySDC/implementations/problem_classes/polynomial_test_problem.py b/pySDC/implementations/problem_classes/polynomial_test_problem.py
index e961289ac..8183e5f98 100644
--- a/pySDC/implementations/problem_classes/polynomial_test_problem.py
+++ b/pySDC/implementations/problem_classes/polynomial_test_problem.py
@@ -28,8 +28,8 @@ class polynomial_testequation(Problem):
         # invoke super init, passing number of dofs, dtype_u and dtype_f
         super().__init__(init=(1, None, np.dtype('float64')))
 
-        self.rng = self.xp.random.RandomState(seed=seed)
-        self.poly = self.xp.polynomial.Polynomial(self.rng.rand(degree))
+        self.rng = np.random.RandomState(seed=seed)
+        self.poly = np.polynomial.Polynomial(self.rng.rand(degree))
         self._makeAttributeAndRegister('degree', 'seed', localVars=locals(), readOnly=True)
 
     def eval_f(self, u, t):
@@ -50,7 +50,7 @@ class polynomial_testequation(Problem):
         """
 
         f = self.dtype_f(self.init)
-        f[:] = self.poly.deriv(m=1)(t)
+        f[:] = self.xp.array(self.poly.deriv(m=1)(t))
         return f
 
     def solve_system(self, rhs, factor, u0, t):
@@ -95,7 +95,7 @@ class polynomial_testequation(Problem):
             The exact solution.
         """
         me = self.dtype_u(self.init)
-        me[:] = self.poly(t)
+        me[:] = self.xp.array(self.poly(t))
         return me
 
 
@@ -125,7 +125,7 @@ class polynomial_testequation_IMEX(polynomial_testequation):
         """
 
         f = self.dtype_f(self.init)
-        derivative = self.poly.deriv(m=1)(t)
+        derivative = self.xp.array(self.poly.deriv(m=1)(t))
         f.impl[:] = derivative / 2
         f.expl[:] = derivative / 2
         return f
-- 
GitLab