diff --git a/src/datastore.py b/src/datastore.py index d9f844ff97acb3f5c6600205f91100219d9c53e6..4b70908dd94cbf6062f31b5da0c3daf6ec92352a 100644 --- a/src/datastore.py +++ b/src/datastore.py @@ -3,6 +3,9 @@ __date__ = '2019-11-22' from abc import ABC +from functools import wraps +import inspect +import types from typing import Any, List, Tuple, Dict @@ -27,6 +30,39 @@ class EmptyScope(Exception): pass +class CorrectScope: + + def __init__(self, func): + wraps(func)(self) + + def __call__(self, *args, **kwargs): + f_arg = inspect.getfullargspec(self.__wrapped__) + pos_scope = f_arg.args.index("scope") + if len(args) < (len(f_arg.args) - len(f_arg.defaults or "")): + new_arg = kwargs.pop("scope", "general") or "general" + args = self.update_tuple(args, new_arg, pos_scope) + else: + args = self.update_tuple(args, self.correct(args[pos_scope]), pos_scope, update=True) + return self.__wrapped__(*args, **kwargs) + + def __get__(self, instance, cls): + if instance is None: + return self + else: + return types.MethodType(self, instance) + + @staticmethod + def correct(arg: str): + if not arg.startswith("general"): + arg = "general." + arg + return arg + + @staticmethod + def update_tuple(t, new, ind, update=False): + t_new = (*t[:ind], new, *t[ind + update:]) + return t_new + + class AbstractDataStore(ABC): """ @@ -119,6 +155,7 @@ class DataStoreByVariable(AbstractDataStore): <scope3>: value """ + @CorrectScope def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are @@ -132,6 +169,7 @@ class DataStoreByVariable(AbstractDataStore): self._store[name] = {} self._store[name][scope] = obj + @CorrectScope def get(self, name: str, scope: str) -> Any: """ Retrieve an object with `name` from `scope`. If no object can be found in the exact scope, take an iterative @@ -144,6 +182,7 @@ class DataStoreByVariable(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] + @CorrectScope def get_default(self, name: str, scope: str, default: Any) -> Any: """ Same functionality like the standard get method. But this method adds a default argument that is returned if no @@ -160,6 +199,7 @@ class DataStoreByVariable(AbstractDataStore): except (NameNotFoundInDataStore, NameNotFoundInScope): return default + @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): local_scope = scope.rsplit(".", maxsplit=depth)[0] @@ -183,6 +223,7 @@ class DataStoreByVariable(AbstractDataStore): """ return sorted(self._store[name] if name in self._store.keys() else []) + @CorrectScope def search_scope(self, scope: str, current_scope_only=True, return_all=False) -> List[str or Tuple]: """ Search for given `scope` and list all object names stored under this scope. To look also for all superior scopes @@ -259,6 +300,7 @@ class DataStoreByScope(AbstractDataStore): <variable3>: value """ + @CorrectScope def set(self, name: str, obj: Any, scope: str) -> None: """ Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are @@ -271,6 +313,7 @@ class DataStoreByScope(AbstractDataStore): self._store[scope] = {} self._store[scope][name] = obj + @CorrectScope def get(self, name: str, scope: str) -> Any: """ Retrieve an object with `name` from `scope`. If no object can be found in the exact scope, take an iterative @@ -283,6 +326,7 @@ class DataStoreByScope(AbstractDataStore): """ return self._stride_through_scopes(name, scope)[2] + @CorrectScope def get_default(self, name: str, scope: str, default: Any) -> Any: """ Same functionality like the standard get method. But this method adds a default argument that is returned if no @@ -299,6 +343,7 @@ class DataStoreByScope(AbstractDataStore): except (NameNotFoundInDataStore, NameNotFoundInScope): return default + @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): local_scope = scope.rsplit(".", maxsplit=depth)[0] @@ -326,6 +371,7 @@ class DataStoreByScope(AbstractDataStore): keys.append(key) return sorted(keys) + @CorrectScope def search_scope(self, scope: str, current_scope_only: bool = True, return_all: bool = False) -> List[str or Tuple]: """ Search for given `scope` and list all object names stored under this scope. To look also for all superior scopes diff --git a/test/test_datastore.py b/test/test_datastore.py index 9fcb319f51954b365c59274a4a9744f093e155f1..6f454d539e8df49344dd4e8d8440f1abebf815fd 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -68,7 +68,7 @@ class TestDataStoreByVariable: ds.set("number", 3, "general") assert ds.get_default("number", "general", 45) == 3 assert ds.get_default("number", "general.sub", 45) == 3 - assert ds.get_default("number", "other", 45) == 45 + assert ds.get_default("other", 45) == 45 def test_search(self, ds): ds.set("number", 22, "general") @@ -161,6 +161,19 @@ class TestDataStoreByVariable: assert ds.get("tester1", "general.sub") == 111 assert ds.get("tester3", "general.sub") == 21 + def test_no_scope_given(self, ds): + ds.set("tester", 34) + assert ds._store["tester"]["general"] == 34 + assert ds.get("tester") == 34 + assert ds.get("tester", "sub") == 34 + ds.set("tester", 99, "sub") + assert ds.list_all_scopes() == ["general", "general.sub"] + assert ds.get_default("test2", 4) == 4 + assert ds.get_default("tester", "sub", 4) == 99 + ds.set("test2", 4) + assert sorted(ds.search_scope(current_scope_only=False)) == sorted(["tester", "test2"]) + assert ds.search_scope("sub", current_scope_only=True) == ["tester"] + class TestDataStoreByScope: @@ -206,7 +219,7 @@ class TestDataStoreByScope: ds.set("number", 3, "general") assert ds.get_default("number", "general", 45) == 3 assert ds.get_default("number", "general.sub", 45) == 3 - assert ds.get_default("number", "other", 45) == 45 + assert ds.get_default("other", "other", 45) == 45 def test_search(self, ds): ds.set("number", 22, "general") @@ -297,4 +310,17 @@ class TestDataStoreByScope: assert ds.get("tester3", "general") == 21 ds.set_args_from_dict({"tester1": 111}, "general.sub") assert ds.get("tester1", "general.sub") == 111 - assert ds.get("tester3", "general.sub") == 21 \ No newline at end of file + assert ds.get("tester3", "general.sub") == 21 + + def test_no_scope_given(self, ds): + ds.set("tester", 34) + assert ds._store["general"]["tester"] == 34 + assert ds.get("tester") == 34 + assert ds.get("tester", "sub") == 34 + ds.set("tester", 99, "sub") + assert ds.list_all_scopes() == ["general", "general.sub"] + assert ds.get_default("test2", 4) == 4 + assert ds.get_default("tester", "sub", 4) == 99 + ds.set("test2", 4) + assert sorted(ds.search_scope(current_scope_only=False)) == sorted(["tester", "test2"]) + assert ds.search_scope("sub", current_scope_only=True) == ["tester"]