Skip to content
Snippets Groups Projects
Commit 2c03cf86 authored by lukas leufen's avatar lukas leufen
Browse files

corrected test

parent 8fb227b7
No related branches found
No related tags found
3 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!91WIP: Resolve "create sphinx docu"
Pipeline #35635 passed
......@@ -6,6 +6,7 @@ __author__ = 'Lukas Leufen'
__date__ = '2019-11-22'
import inspect
import logging
import types
from abc import ABC
from functools import wraps
......@@ -109,13 +110,14 @@ class AbstractDataStore(ABC):
"""Initialise by creating empty data store."""
self._store: Dict = {}
def set(self, name: str, obj: Any, scope: str) -> None:
def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None:
"""
Abstract method to add an object to the data store.
:param name: Name of object to store
:param obj: The object itself to be stored
:param scope: the scope / context of the object, under that the object is valid
:param log: log which objects are stored if enabled (default false)
"""
pass
......@@ -196,7 +198,7 @@ class AbstractDataStore(ABC):
return args
@CorrectScope
def set_from_dict(self, arg_dict: Dict, scope: str) -> None:
def set_from_dict(self, arg_dict: Dict, scope: str, log: bool = False) -> None:
"""
Store multiple objects from dictionary under same `scope`.
......@@ -205,9 +207,10 @@ class AbstractDataStore(ABC):
:param arg_dict: updates for the data store, provided as key value pairs
:param scope: scope to store updates
:param log: log which objects are stored if enabled (default false)
"""
for (k, v) in arg_dict.items():
self.set(k, v, scope)
self.set(k, v, scope, log=log)
class DataStoreByVariable(AbstractDataStore):
......@@ -232,7 +235,7 @@ class DataStoreByVariable(AbstractDataStore):
"""
@CorrectScope
def set(self, name: str, obj: Any, scope: str) -> None:
def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None:
"""
Store an object `obj` with given `name` under `scope`.
......@@ -241,11 +244,14 @@ class DataStoreByVariable(AbstractDataStore):
:param name: Name of object to store
:param obj: The object itself to be stored
:param scope: the scope / context of the object, under that the object is valid
:param log: log which objects are stored if enabled (default false)
"""
# open new variable related store with `name` as key if not existing
if name not in self._store.keys():
self._store[name] = {}
self._store[name][scope] = obj
if log:
logging.debug(f"set: {name}({scope})={obj}")
@CorrectScope
def get(self, name: str, scope: str) -> Any:
......@@ -401,7 +407,7 @@ class DataStoreByScope(AbstractDataStore):
"""
@CorrectScope
def set(self, name: str, obj: Any, scope: str) -> None:
def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None:
"""
Store an object `obj` with given `name` under `scope`.
......@@ -410,10 +416,13 @@ class DataStoreByScope(AbstractDataStore):
:param name: Name of object to store
:param obj: The object itself to be stored
:param scope: the scope / context of the object, under that the object is valid
:param log: log which objects are stored if enabled (default false)
"""
if scope not in self._store.keys():
self._store[scope] = {}
self._store[scope][name] = obj
if log:
logging.debug(f"set: {name}({scope})={obj}")
@CorrectScope
def get(self, name: str, scope: str) -> Any:
......
......@@ -137,7 +137,7 @@ class ModelSetup(RunEnvironment):
def get_model_settings(self):
"""Load all model settings and store in data store."""
model_settings = self.model.get_settings()
self.data_store.set_from_dict(model_settings, self.scope)
self.data_store.set_from_dict(model_settings, self.scope, log=True)
self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
self.data_store.set("model_name", self.model_name, self.scope)
......
......@@ -16,7 +16,7 @@ class TestModelSetup:
def setup(self):
obj = object.__new__(ModelSetup)
super(ModelSetup, obj).__init__()
obj.scope = "general.modeltest"
obj.scope = "general.model"
obj.model = None
obj.callbacks_name = "placeholder_%s_str.pickle"
obj.data_store.set("lr_decay", "dummy_str", "general.model")
......@@ -58,24 +58,25 @@ class TestModelSetup:
return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))
def test_set_callbacks(self, setup):
assert "general.modeltest" not in setup.data_store.search_name("callbacks")
assert "general.model" not in setup.data_store.search_name("callbacks")
setup.checkpoint_name = "TestName"
setup._set_callbacks()
assert "general.modeltest" in setup.data_store.search_name("callbacks")
callbacks = setup.data_store.get("callbacks", "general.modeltest")
assert "general.model" in setup.data_store.search_name("callbacks")
callbacks = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 3
def test_set_callbacks_no_lr_decay(self, setup):
setup.data_store.set("lr_decay", None, "general.model")
assert "general.modeltest" not in setup.data_store.search_name("callbacks")
assert "general.model" not in setup.data_store.search_name("callbacks")
setup.checkpoint_name = "TestName"
setup._set_callbacks()
callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.modeltest")
callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
assert len(callbacks.get_callbacks()) == 2
with pytest.raises(IndexError):
callbacks.get_callback_by_name("lr_decay")
def test_get_model_settings(self, setup_with_model):
setup_with_model.scope = "model_test"
with pytest.raises(EmptyScope):
self.current_scope_as_set(setup_with_model) # will fail because scope is not created
setup_with_model.get_model_settings() # this saves now the parameters epochs and batch_size into scope
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment