Skip to content
Snippets Groups Projects
Commit 6917b04a authored by leufen1's avatar leufen1
Browse files

MLAir can store parameters during preprocessing from the train subset when using store_attributes

parent 4c596cc8
No related branches found
No related tags found
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!317enabled window_lead_time=1,!295Resolve "data handler FIR filter",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #66871 passed
...@@ -11,6 +11,7 @@ from mlair.helpers import remove_items ...@@ -11,6 +11,7 @@ from mlair.helpers import remove_items
class AbstractDataHandler: class AbstractDataHandler:
_requirements = [] _requirements = []
_store_attributes = []
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
...@@ -32,6 +33,31 @@ class AbstractDataHandler: ...@@ -32,6 +33,31 @@ class AbstractDataHandler:
list_of_args = arg_spec.args + arg_spec.kwonlyargs list_of_args = arg_spec.args + arg_spec.kwonlyargs
return remove_items(list_of_args, ["self"] + list(args)) return remove_items(list_of_args, ["self"] + list(args))
@classmethod
def store_attributes(cls) -> list:
    """
    Let MLAir know that some data should be stored in the data store. This is used for calculations on the train
    subset that should be applied to validation and test subset.

    To work properly, add a class variable cls._store_attributes to your data handler. If your custom data handler
    is constructed on different data handlers (e.g. like the DefaultDataHandler), it is required to overwrite the
    get_store_attributes method in addition to return attributes from the corresponding subclasses. This is not
    required, if only attributes from the main class are to be returned.

    Note, that MLAir will store these attributes with the data handler's identification. This depends on the custom
    data handler setting. When loading an attribute from the data handler, it is therefore required to extract the
    right information by using the class identification. In case of the DefaultDataHandler this can be achieved by
    converting all keys of the attribute to string and comparing these with the station parameter.
    """
    # Deduplicate entries (subclasses may accumulate repeats); note set() does not preserve order.
    return list(set(cls._store_attributes))
def get_store_attributes(self) -> dict:
    """Return all attribute names and values that are indicated by the store_attributes method.

    :return: mapping from each stored attribute name to its current value on this instance
    """
    # getattr is the idiomatic equivalent of self.__getattribute__(attr)
    return {attr: getattr(self, attr) for attr in self.store_attributes()}
@classmethod @classmethod
def transformation(cls, *args, **kwargs): def transformation(cls, *args, **kwargs):
return None return None
......
...@@ -33,6 +33,7 @@ class DefaultDataHandler(AbstractDataHandler): ...@@ -33,6 +33,7 @@ class DefaultDataHandler(AbstractDataHandler):
from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation
_requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"]) _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"])
_store_attributes = data_handler.store_attributes()
DEFAULT_ITER_DIM = "Stations" DEFAULT_ITER_DIM = "Stations"
DEFAULT_TIME_DIM = "datetime" DEFAULT_TIME_DIM = "datetime"
...@@ -93,6 +94,16 @@ class DefaultDataHandler(AbstractDataHandler): ...@@ -93,6 +94,16 @@ class DefaultDataHandler(AbstractDataHandler):
logging.debug(f"save pickle data to {self._save_file}") logging.debug(f"save pickle data to {self._save_file}")
self._reset_data() self._reset_data()
def get_store_attributes(self) -> dict:
    """Return all attributes indicated by store_attributes, resolved on this handler first.

    Attributes that are not defined on this handler itself are looked up on the wrapped
    inner data handler (self.id_class) instead.

    :return: mapping from each stored attribute name to its value
    """
    attr_dict = {}
    for attr in self.store_attributes():
        try:
            # EAFP: prefer the attribute on this handler itself.
            val = getattr(self, attr)
        except AttributeError:
            # Fall back to the wrapped inner data handler.
            val = getattr(self.id_class, attr)
        attr_dict[attr] = val
    return attr_dict
@staticmethod @staticmethod
def _force_dask_computation(data): def _force_dask_computation(data):
try: try:
......
...@@ -268,8 +268,22 @@ class PreProcessing(RunEnvironment): ...@@ -268,8 +268,22 @@ class PreProcessing(RunEnvironment):
logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/" logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). Found {len(collection)}/"
f"{len(set_stations)} valid stations.") f"{len(set_stations)} valid stations.")
if set_name == "train":
self.store_data_handler_attributes(data_handler, collection)
return collection, valid_stations return collection, valid_stations
def store_data_handler_attributes(self, data_handler, collection):
    """Store attributes requested by the data handler in the data store.

    For each attribute name reported by data_handler.store_attributes(), collect the per-station
    values from all handlers in the collection and store them in the data store as a single dict
    keyed by the station identification (str(dh)).

    :param data_handler: the data handler class used to build the collection
    :param collection: iterable of instantiated data handlers (one per valid station)
    """
    store_attributes = data_handler.store_attributes()
    if len(store_attributes) > 0:
        logging.info("store data requested by the data handler")
        attrs = {}
        for dh in collection:
            station = str(dh)
            for k, v in dh.get_store_attributes().items():
                # setdefault updates in place instead of rebuilding the accumulated dict
                # on every insertion (the dict(..., **{...}) form was quadratic).
                attrs.setdefault(k, {})[station] = v
        for k, v in attrs.items():
            self.data_store.set(k, v)
def validate_station_old(self, data_handler: AbstractDataHandler, set_stations, set_name=None, def validate_station_old(self, data_handler: AbstractDataHandler, set_stations, set_name=None,
store_processed_data=True): store_processed_data=True):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment