Skip to content
Snippets Groups Projects

Resolve "release v1.4.0"

Merged Ghost User requested to merge release_v1.4.0 into master
7 files
+ 24
7
Compare changes
  • Side-by-side
  • Inline
Files
7
@@ -6,6 +6,7 @@ import logging
@@ -6,6 +6,7 @@ import logging
import os
import os
import sys
import sys
from typing import Union, Dict, Any, List, Callable
from typing import Union, Dict, Any, List, Callable
 
from dill.source import getsource
from mlair.configuration import path_config
from mlair.configuration import path_config
from mlair import helpers
from mlair import helpers
@@ -217,7 +218,7 @@ class ExperimentSetup(RunEnvironment):
@@ -217,7 +218,7 @@ class ExperimentSetup(RunEnvironment):
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
max_number_multiprocessing: int = None, **kwargs):
max_number_multiprocessing: int = None, start_script: Union[Callable, str] = None, **kwargs):
# create run framework
# create run framework
super().__init__()
super().__init__()
@@ -366,6 +367,10 @@ class ExperimentSetup(RunEnvironment):
@@ -366,6 +367,10 @@ class ExperimentSetup(RunEnvironment):
# set model architecture class
# set model architecture class
self._set_param("model_class", model, VanillaModel)
self._set_param("model_class", model, VanillaModel)
 
# store starting script if provided
 
if start_script is not None:
 
self._store_start_script(start_script, experiment_path)
 
# set remaining kwargs
# set remaining kwargs
if len(kwargs) > 0:
if len(kwargs) > 0:
for k, v in kwargs.items():
for k, v in kwargs.items():
@@ -387,6 +392,18 @@ class ExperimentSetup(RunEnvironment):
@@ -387,6 +392,18 @@ class ExperimentSetup(RunEnvironment):
logging.debug(f"set experiment attribute: {param}({scope})={value}")
logging.debug(f"set experiment attribute: {param}({scope})={value}")
return value
return value
 
@staticmethod
 
def _store_start_script(start_script, store_path):
 
out_file = os.path.join(store_path, "start_script.txt")
 
if isinstance(start_script, Callable):
 
with open(out_file, "w") as fh:
 
fh.write(getsource(start_script))
 
if isinstance(start_script, str):
 
with open(start_script, 'r') as f:
 
with open(out_file, "w") as out:
 
for line in (f.readlines()):
 
print(line, end='', file=out)
 
def _compare_variables_and_statistics(self):
def _compare_variables_and_statistics(self):
"""
"""
Compare variables and statistics.
Compare variables and statistics.
Loading