diff --git a/src/helpers/datastore.py b/src/helpers/datastore.py index a540d6f864775a2b333ecd544d507f28244b137c..cd852067320b3ecb497d19311fbae18d1622d986 100644 --- a/src/helpers/datastore.py +++ b/src/helpers/datastore.py @@ -43,6 +43,9 @@ class CorrectScope: def __init__(self, func): """Construct decorator.""" + setattr(self, "wrapper", func) + if hasattr(func, "__wrapped__"): + func = func.__wrapped__ wraps(func)(self) def __call__(self, *args, **kwargs): @@ -59,7 +62,7 @@ class CorrectScope: args = self.update_tuple(args, new_arg, pos_scope) else: args = self.update_tuple(args, args[pos_scope], pos_scope, update=True) - return self.__wrapped__(*args, **kwargs) + return self.wrapper(*args, **kwargs) def __get__(self, instance, cls): """Create bound method object and supply self argument to the decorated method.""" @@ -97,6 +100,41 @@ class CorrectScope: return t_new +class TrackParameter: + + def __init__(self, func): + """Construct decorator.""" + wraps(func)(self) + + def __call__(self, *args, **kwargs): + """ + Call method of decorator. + """ + self.track(*args) + return self.__wrapped__(*args, **kwargs) + + def __get__(self, instance, cls): + """Create bound method object and supply self argument to the decorated method.""" + return types.MethodType(self, instance) + + def track(self, tracker_obj, *args): + name, obj, scope = self._decrypt_args(*args) + logging.debug(f"{self.__wrapped__.__name__}: {name}({scope})={obj}") + tracker = tracker_obj.tracker + new_entry = [(self.__wrapped__.__name__, scope, obj)] + if tracker.get(name): + tracker[name].append(new_entry) + else: + tracker[name] = new_entry + + @staticmethod + def _decrypt_args(*args): + if len(args) == 2: + return args[0], None, args[1] + else: + return args + + class AbstractDataStore(ABC): """ Abstract data store for all settings for the experiment workflow. @@ -106,6 +144,8 @@ class AbstractDataStore(ABC): adjustments. """ + tracker = {} + def __init__(self): """Initialise by creating empty data store.""" self._store: Dict = {} @@ -235,6 +275,7 @@ class DataStoreByVariable(AbstractDataStore): """ @CorrectScope + @TrackParameter def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None: """ Store an object `obj` with given `name` under `scope`. @@ -254,6 +295,7 @@ class DataStoreByVariable(AbstractDataStore): logging.debug(f"set: {name}({scope})={obj}") @CorrectScope + @TrackParameter def get(self, name: str, scope: str) -> Any: """ Retrieve an object with `name` from `scope`. diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index f5fc1f0fd627120f266b419b150eeb85b62c7389..823995181807a25b3e1759bdb77c639e16f8fd64 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -50,7 +50,7 @@ class ModelSetup(RunEnvironment): * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model] Creates - * plot of model architecture in `<model_name>.pdf` + * plot of model architecture `<model_name>.pdf` """