From 43cf2f5b68a9f1113c5a96fb4329d2cd273e4466 Mon Sep 17 00:00:00 2001
From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com>
Date: Fri, 11 Apr 2025 13:07:46 +0200
Subject: [PATCH] Implemented interface to FieldsIO in core problem class and
 respective hook (#542)

* Implemented interface to FieldsIO in core problem class and generic spectral problem class

* Added comment
---
 etc/environment-mpi4py.yml                    |   1 +
 pySDC/core/problem.py                         |  12 ++
 pySDC/helpers/fieldsIO.py                     |   7 +-
 pySDC/implementations/hooks/log_solution.py   |  65 +++++++-
 .../problem_classes/TestEquation_0D.py        |  10 ++
 .../problem_classes/generic_spectral.py       |  25 +++
 pySDC/projects/GPU/configs/GS_configs.py      |   2 +-
 pySDC/projects/GPU/configs/RBC_configs.py     |   2 +-
 pySDC/projects/GPU/tests/test_configs.py      |   2 +-
 pySDC/tests/test_hooks/test_log_to_file.py    | 143 ++++++++++++++----
 10 files changed, 232 insertions(+), 37 deletions(-)

diff --git a/etc/environment-mpi4py.yml b/etc/environment-mpi4py.yml
index 16b96df6d..dd0556be4 100644
--- a/etc/environment-mpi4py.yml
+++ b/etc/environment-mpi4py.yml
@@ -14,3 +14,4 @@ dependencies:
   - pip
   - pip:
     - qmat>=0.1.8
+    - pytest-isolate-mpi
diff --git a/pySDC/core/problem.py b/pySDC/core/problem.py
index cdfabeb89..55d76b18d 100644
--- a/pySDC/core/problem.py
+++ b/pySDC/core/problem.py
@@ -80,6 +80,18 @@ class Problem(RegisterParams):
     def get_default_sweeper_class(cls):
         raise NotImplementedError(f'No default sweeper class implemented for {cls} problem!')
 
+    def setUpFieldsIO(self):
+        """Set up FieldsIO for MPI with the space decomposition of this problem. Does nothing by default."""
+        pass
+
+    def getOutputFile(self, fileName):
+        """Create the output file for this problem. Needs to be implemented by the specific problem class."""
+        raise NotImplementedError(f'No output file implemented for {type(self).__name__}')
+
+    def processSolutionForOutput(self, u):
+        """Convert the solution to the form that is written to file. The default is to return it unchanged."""
+        return u
+
     def eval_f(self, u, t):
         """
         Abstract interface to RHS computation of the ODE
diff --git a/pySDC/helpers/fieldsIO.py b/pySDC/helpers/fieldsIO.py
index 5ac312ac0..fe0e8cbd0 100644
--- a/pySDC/helpers/fieldsIO.py
+++ b/pySDC/helpers/fieldsIO.py
@@ -198,9 +198,10 @@ class FieldsIO:
         assert not self.initialized, "FieldsIO already initialized"
 
         if not self.ALLOW_OVERWRITE:
-            assert not os.path.isfile(
-                self.fileName
-            ), f"file {self.fileName!r} already exists, use FieldsIO.ALLOW_OVERWRITE = True to allow overwriting"
+            if os.path.isfile(self.fileName):
+                raise FileExistsError(
+                    f"file {self.fileName!r} already exists, use FieldsIO.ALLOW_OVERWRITE = True to allow overwriting"
+                )
 
         with open(self.fileName, "w+b") as f:
             self.hBase.tofile(f)
diff --git a/pySDC/implementations/hooks/log_solution.py b/pySDC/implementations/hooks/log_solution.py
index 9cd5ba8e8..e7e450101 100644
--- a/pySDC/implementations/hooks/log_solution.py
+++ b/pySDC/implementations/hooks/log_solution.py
@@ -2,6 +2,8 @@ from pySDC.core.hooks import Hooks
 import pickle
 import os
 import numpy as np
+from pySDC.helpers.fieldsIO import FieldsIO
+from pySDC.core.errors import DataError
 
 
 class LogSolution(Hooks):
@@ -68,7 +70,7 @@ class LogSolutionAfterIteration(Hooks):
         )
 
 
-class LogToFile(Hooks):
+class LogToPickleFile(Hooks):
     r"""
     Hook for logging the solution to file after the step using pickle.
 
@@ -171,7 +173,7 @@ class LogToFile(Hooks):
             return pickle.load(file)
 
 
-class LogToFileAfterXs(LogToFile):
+class LogToPickleFileAfterXS(LogToPickleFile):
     r'''
     Log to file after certain amount of time has passed instead of after every step
     '''
@@ -200,3 +202,62 @@ class LogToFileAfterXs(LogToFile):
             }
 
         self.log_to_file(step, level_number, type(self).logging_condition(L), process_solution=process_solution)
+
+
+class LogToFile(Hooks):
+    """Hook for logging the solution on level 0 to file via FieldsIO, at least `time_increment` apart in time."""
+    filename = 'myRun.pySDC'  # path of the output file
+    time_increment = 0  # minimal simulation time between two outputs; 0 writes after every step
+    allow_overwriting = False  # whether existing files and already recorded times may be overwritten
+
+    def __init__(self):
+        super().__init__()
+        self.outfile = None
+        self.t_next_log = 0
+        FieldsIO.ALLOW_OVERWRITE = self.allow_overwriting  # NOTE: this is set on the FieldsIO class, not per file
+
+    def pre_run(self, step, level_number):
+        """Open the output file: reuse an existing file when restarting at t > 0, else create it and store u0."""
+        if level_number > 0:
+            return None
+        L = step.levels[level_number]
+
+        # setup outfile
+        if os.path.isfile(self.filename) and L.time > 0:
+            L.prob.setUpFieldsIO()  # configure FieldsIO for this problem before opening the existing file
+            self.outfile = FieldsIO.fromFile(self.filename)
+            self.logger.info(
+                f'Set up file {self.filename!r} for writing output. This file already contains solutions up to t={self.outfile.times[-1]:.4f}.'
+            )
+        else:
+            self.outfile = L.prob.getOutputFile(self.filename)
+            self.logger.info(f'Set up file {self.filename!r} for writing output.')
+
+            # write initial conditions
+            if L.time not in self.outfile.times:
+                self.outfile.addField(time=L.time, field=L.prob.processSolutionForOutput(L.u[0]))
+                self.logger.info(f'Written initial conditions at t={L.time:.4f} to file')
+
+    def post_step(self, step, level_number):
+        """Append the solution at the end of the step once `t_next_log` is reached; restarted steps are skipped."""
+        if level_number > 0:
+            return None
+
+        L = step.levels[level_number]
+
+        if self.t_next_log == 0:  # 0 marks "not set yet": the first output is due one increment after step start
+            self.t_next_log = L.time + self.time_increment
+
+        if L.time + L.dt >= self.t_next_log and not step.status.restart:
+            value_exists = any(abs(me - (L.time + L.dt)) < np.finfo(float).eps * 1000 for me in self.outfile.times)
+            if value_exists and not self.allow_overwriting:
+                raise DataError(f'Already have recorded data for time {L.time + L.dt} in this file!')
+            self.outfile.addField(time=L.time + L.dt, field=L.prob.processSolutionForOutput(L.uend))
+            self.logger.info(f'Written solution at t={L.time+L.dt:.4f} to file')
+            self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment
+
+    @classmethod
+    def load(cls, index):
+        """Read entry `index` from the file `cls.filename` and return it as a dict with keys 'u' and 't'."""
+        entry = FieldsIO.fromFile(cls.filename).readField(idx=index)
+        return {'u': entry[1], 't': entry[0]}
diff --git a/pySDC/implementations/problem_classes/TestEquation_0D.py b/pySDC/implementations/problem_classes/TestEquation_0D.py
index 4276e96e9..5262b165f 100644
--- a/pySDC/implementations/problem_classes/TestEquation_0D.py
+++ b/pySDC/implementations/problem_classes/TestEquation_0D.py
@@ -3,6 +3,7 @@ import scipy.sparse as nsp
 
 from pySDC.core.problem import Problem, WorkCounter
 from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
+from pySDC.helpers.fieldsIO import Scalar
 
 
 class testequation0d(Problem):
@@ -146,6 +147,15 @@ class testequation0d(Problem):
         me[:] = u_init * self.xp.exp((t - t_init) * self.lambdas)
         return me
 
+    def getOutputFile(self, fileName):
+        fOut = Scalar(np.complex128, fileName=fileName)  # values are stored as complex128
+        fOut.setHeader(self.lambdas.size)  # NOTE(review): presumably the number of variables -- confirm
+        fOut.initialize()  # the header must be set before the file can be initialized
+        return fOut
+
+    def processSolutionForOutput(self, u):
+        return u.flatten()  # output is written as a one-dimensional array
+
 
 class test_equation_IMEX(Problem):
     dtype_f = imex_mesh
diff --git a/pySDC/implementations/problem_classes/generic_spectral.py b/pySDC/implementations/problem_classes/generic_spectral.py
index 4f0fa3426..a8c4bc260 100644
--- a/pySDC/implementations/problem_classes/generic_spectral.py
+++ b/pySDC/implementations/problem_classes/generic_spectral.py
@@ -2,6 +2,7 @@ from pySDC.core.problem import Problem, WorkCounter
 from pySDC.helpers.spectral_helper import SpectralHelper
 import numpy as np
 from pySDC.core.errors import ParameterError
+from pySDC.helpers.fieldsIO import Rectilinear
 
 
 class GenericSpectralLinear(Problem):
@@ -333,6 +334,30 @@ class GenericSpectralLinear(Problem):
 
             return sol
 
+    def setUpFieldsIO(self):
+        Rectilinear.setupMPI(
+            comm=self.comm,
+            iLoc=[me.start for me in self.local_slice],  # first index of the local block in each slice
+            nLoc=[me.stop - me.start for me in self.local_slice],  # length of the local block in each slice
+        )
+
+    def getOutputFile(self, fileName):
+        self.setUpFieldsIO()  # hand the MPI decomposition to Rectilinear before the file is created
+
+        coords = [me.get_1dgrid() for me in self.spectral.axes]  # one 1D grid per axis
+        assert np.allclose([len(me) for me in coords], self.spectral.global_shape[1:])
+        # NOTE(review): global_shape[0] is skipped above; presumably it counts the components -- confirm
+        fOut = Rectilinear(np.float64, fileName=fileName)  # values are stored as float64
+        fOut.setHeader(nVar=len(self.components), coords=coords)
+        fOut.initialize()
+        return fOut
+
+    def processSolutionForOutput(self, u):
+        if self.spectral_space:
+            return np.array(self.itransform(u).real)  # apply itransform first, then keep only the real part
+        else:
+            return np.array(u.real)  # keep only the real part
+
 
 def compute_residual_DAE(self, stage=''):
     """
diff --git a/pySDC/projects/GPU/configs/GS_configs.py b/pySDC/projects/GPU/configs/GS_configs.py
index 5fde8848d..8c590e879 100644
--- a/pySDC/projects/GPU/configs/GS_configs.py
+++ b/pySDC/projects/GPU/configs/GS_configs.py
@@ -36,7 +36,7 @@ class GrayScott(Config):
 
     def get_LogToFile(self, ranks=None):
         import numpy as np
-        from pySDC.implementations.hooks.log_solution import LogToFileAfterXs as LogToFile
+        from pySDC.implementations.hooks.log_solution import LogToPickleFileAfterXS as LogToFile
 
         LogToFile.path = f'{self.base_path}/data/'
         LogToFile.file_name = f'{self.get_path(ranks=ranks)}-solution'
diff --git a/pySDC/projects/GPU/configs/RBC_configs.py b/pySDC/projects/GPU/configs/RBC_configs.py
index 28cea340c..10076257d 100644
--- a/pySDC/projects/GPU/configs/RBC_configs.py
+++ b/pySDC/projects/GPU/configs/RBC_configs.py
@@ -31,7 +31,7 @@ class RayleighBenardRegular(Config):
 
     def get_LogToFile(self, ranks=None):
         import numpy as np
-        from pySDC.implementations.hooks.log_solution import LogToFileAfterXs as LogToFile
+        from pySDC.implementations.hooks.log_solution import LogToPickleFileAfterXS as LogToFile
 
         LogToFile.path = f'{self.base_path}/data/'
         LogToFile.file_name = f'{self.get_path(ranks=ranks)}-solution'
diff --git a/pySDC/projects/GPU/tests/test_configs.py b/pySDC/projects/GPU/tests/test_configs.py
index ac25d813b..ab27a2621 100644
--- a/pySDC/projects/GPU/tests/test_configs.py
+++ b/pySDC/projects/GPU/tests/test_configs.py
@@ -64,7 +64,7 @@ def test_run_experiment(restart_idx=0):
             return desc
 
         def get_LogToFile(self, ranks=None):
-            from pySDC.implementations.hooks.log_solution import LogToFileAfterXs as LogToFile
+            from pySDC.implementations.hooks.log_solution import LogToPickleFileAfterXS as LogToFile
 
             LogToFile.path = './data/'
             LogToFile.file_name = f'{self.get_path(ranks=ranks)}-solution'
diff --git a/pySDC/tests/test_hooks/test_log_to_file.py b/pySDC/tests/test_hooks/test_log_to_file.py
index d786b2110..f5eb06fe4 100644
--- a/pySDC/tests/test_hooks/test_log_to_file.py
+++ b/pySDC/tests/test_hooks/test_log_to_file.py
@@ -1,12 +1,22 @@
 import pytest
 
 
-def run(hook, Tend=0):
-    from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d
-    from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
+def run(hook, Tend=0, ODE=True, t0=0):
     from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
+    from pySDC.helpers.fieldsIO import FieldsIO
 
-    level_params = {'dt': 1.0e-1}
+    if ODE:
+        from pySDC.implementations.problem_classes.TestEquation_0D import testequation0d as problem_class
+        from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit as sweeper_class
+
+        problem_params = {'u0': 1.0}
+    else:
+        from pySDC.implementations.problem_classes.RayleighBenard import RayleighBenard as problem_class
+        from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order as sweeper_class
+
+        problem_params = {'nx': 16, 'nz': 8, 'spectral_space': False}
+
+    level_params = {'dt': 1.0e-2}
 
     sweeper_params = {
         'num_nodes': 1,
@@ -15,71 +25,146 @@ def run(hook, Tend=0):
 
     description = {
         'level_params': level_params,
-        'sweeper_class': generic_implicit,
-        'problem_class': testequation0d,
+        'sweeper_class': sweeper_class,
+        'problem_class': problem_class,
         'sweeper_params': sweeper_params,
-        'problem_params': {},
+        'problem_params': problem_params,
         'step_params': {'maxiter': 1},
     }
 
     controller_params = {
         'hook_class': hook,
-        'logger_level': 30,
+        'logger_level': 15,
     }
     controller = controller_nonMPI(1, controller_params, description)
     if Tend > 0:
         prob = controller.MS[0].levels[0].prob
         u0 = prob.u_exact(0)
+        if t0 > 0:
+            u0[:] = hook.load(-1)['u']
 
-        _, stats = controller.run(u0, 0, Tend)
+        _, stats = controller.run(u0, t0, Tend)
         return u0, stats
 
 
 @pytest.mark.base
-def test_errors():
-    from pySDC.implementations.hooks.log_solution import LogToFile
+def test_errors_pickle():
+    from pySDC.implementations.hooks.log_solution import LogToPickleFile
     import os
 
     with pytest.raises(ValueError):
-        run(LogToFile)
+        run(LogToPickleFile)
 
-    LogToFile.path = os.getcwd()
-    run(LogToFile)
+    LogToPickleFile.path = os.getcwd()
+    run(LogToPickleFile)
 
     path = f'{os.getcwd()}/tmp'
-    LogToFile.path = path
-    run(LogToFile)
+    LogToPickleFile.path = path
+    run(LogToPickleFile)
     os.path.isdir(path)
 
     with pytest.raises(ValueError):
-        LogToFile.path = __file__
-        run(LogToFile)
+        LogToPickleFile.path = __file__
+        run(LogToPickleFile)
+
+
+@pytest.mark.base
+def test_errors_FieldsIO(tmpdir):
+    """Check that LogToFile refuses to overwrite files or already recorded times unless this is allowed."""
+    from pySDC.implementations.hooks.log_solution import LogToFile as hook
+    from pySDC.core.errors import DataError
+
+    path = f'{tmpdir}/FieldsIO_test.pySDC'
+    hook.filename = path
+
+    run_kwargs = {'hook': hook, 'Tend': 0.2, 'ODE': True}
+
+    # create file
+    run(**run_kwargs)
+
+    # test that we cannot overwrite if we don't want to
+    hook.allow_overwriting = False
+    with pytest.raises(FileExistsError):
+        run(**run_kwargs)
+
+    # test that we can overwrite if we do want to
+    hook.allow_overwriting = True
+    run(**run_kwargs)
+
+    # test that we cannot add solutions at times that already exist
+    hook.allow_overwriting = False
+    with pytest.raises(DataError):
+        run(**run_kwargs, t0=0.1)
 
 
 @pytest.mark.base
-def test_logging():
-    from pySDC.implementations.hooks.log_solution import LogToFile, LogSolution
+@pytest.mark.parametrize('use_pickle', [True, False])
+def test_logging(tmpdir, use_pickle, ODE=True):
+    from pySDC.implementations.hooks.log_solution import LogToPickleFile, LogSolution, LogToFile
     from pySDC.helpers.stats_helper import get_sorted
     import os
     import pickle
     import numpy as np
 
-    path = f'{os.getcwd()}/tmp'
-    LogToFile.path = path
-    Tend = 2
+    path = tmpdir
+    Tend = 0.2
+
+    if use_pickle:
+        logging_hook = LogToPickleFile
+        LogToPickleFile.path = path
+    else:
+        logging_hook = LogToFile
+        logging_hook.filename = f'{path}/FieldsIO_test.pySDC'
 
-    u0, stats = run([LogToFile, LogSolution], Tend=Tend)
+    u0, stats = run([logging_hook, LogSolution], Tend=Tend, ODE=ODE)
     u = [(0.0, u0)] + get_sorted(stats, type='u')
 
     u_file = []
     for i in range(len(u)):
-        data = LogToFile.load(i)
+        data = logging_hook.load(i)
         u_file += [(data['t'], data['u'])]
 
     for us, uf in zip(u, u_file):
-        assert us[0] == uf[0]
-        assert np.allclose(us[1], uf[1])
+        assert us[0] == uf[0], 'time does not match'
+        assert np.allclose(us[1], uf[1]), 'solution does not match'
+
+
+@pytest.mark.base
+def test_restart(tmpdir, ODE=True):
+    """Check that a run restarted from file at t=0.1 records the same times and solutions as a continuous run."""
+    from pySDC.implementations.hooks.log_solution import LogToFile
+    import numpy as np
+
+    Tend = 0.2
+    # run the whole thing
+    logging_hook = LogToFile
+    logging_hook.filename = f'{tmpdir}/file.pySDC'
+
+    _, _ = run([logging_hook], Tend=Tend, ODE=ODE)
+
+    u_continuous = []
+    for i in range(20):
+        data = logging_hook.load(i)
+        u_continuous += [(data['t'], data['u'])]
+
+    # run again with a restart in the middle
+    logging_hook.filename = f'{tmpdir}/file2.pySDC'
+    _, _ = run(logging_hook, Tend=0.1, ODE=ODE)
+    _, _ = run(logging_hook, Tend=0.2, t0=0.1, ODE=ODE)
+
+    u_restart = []
+    for i in range(20):
+        data = logging_hook.load(i)
+        u_restart += [(data['t'], data['u'])]
+
+    assert np.allclose([me[0] for me in u_restart], [me[0] for me in u_continuous]), 'Times don\'t match'
+    for u1, u2 in zip(u_restart, u_continuous):
+        assert np.allclose(u1[1], u2[1]), 'solution does not match'
 
 
-if __name__ == '__main__':
-    test_logging()
+@pytest.mark.mpi4py
+@pytest.mark.mpi(ranks=[1, 4])
+def test_loggingMPI(tmpdir, comm, mpi_ranks):
+    # `mpi_ranks` is a pytest fixture required by pytest-isolate-mpi. Do not remove.
+    tmpdir = comm.bcast(tmpdir)  # use the same tmpdir on every rank so that all ranks access the same file
+    test_logging(tmpdir, False, False)  # use_pickle=False, ODE=False
-- 
GitLab