diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 8036413c8aefc3f70f8c24e59812c1a3b3324de1..209859c1ff38efe2667c918aa5b79c96f2524be0 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -6,6 +6,7 @@ import logging import os import sys from typing import Union, Dict, Any, List, Callable +from dill.source import getsource from mlair.configuration import path_config from mlair import helpers @@ -217,7 +218,7 @@ class ExperimentSetup(RunEnvironment): hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, data_origin: Dict = None, competitors: list = None, competitor_path: str = None, use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None, - max_number_multiprocessing: int = None, **kwargs): + max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, **kwargs): # create run framework super().__init__() @@ -366,6 +367,10 @@ class ExperimentSetup(RunEnvironment): # set model architecture class self._set_param("model_class", model, VanillaModel) + # store starting script if provided + if start_script is not None: + self._store_start_script(start_script, experiment_path) + # set remaining kwargs if len(kwargs) > 0: for k, v in kwargs.items(): @@ -387,6 +392,18 @@ class ExperimentSetup(RunEnvironment): logging.debug(f"set experiment attribute: {param}({scope})={value}") return value + @staticmethod + def _store_start_script(start_script, store_path): + out_file = os.path.join(store_path, "start_script.txt") + if isinstance(start_script, Callable): + with open(out_file, "w") as fh: + fh.write(getsource(start_script)) + if isinstance(start_script, str): + with open(start_script, 'r') as f: + with open(out_file, "w") as out: + for line in (f.readlines()): + print(line, end='', file=out) + def _compare_variables_and_statistics(self): """ Compare variables and statistics.