diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py index 419db059a58beeb4ed7e3e198e41b565f8dc7d25..c020a4134f79fe8a7446f45a791eda9057dc6885 100644 --- a/mlair/data_handler/abstract_data_handler.py +++ b/mlair/data_handler/abstract_data_handler.py @@ -11,6 +11,7 @@ from mlair.helpers import remove_items class AbstractDataHandler: _requirements = [] + _store_attributes = [] def __init__(self, *args, **kwargs): pass @@ -32,6 +33,31 @@ class AbstractDataHandler: list_of_args = arg_spec.args + arg_spec.kwonlyargs return remove_items(list_of_args, ["self"] + list(args)) + @classmethod + def store_attributes(cls): + """ + Let MLAir know that some data should be stored in the data store. This is used for calculations on the train + subset that should be applied to validation and test subsets. + + To work properly, add a class variable cls._store_attributes to your data handler. If your custom data handler + is constructed on different data handlers (e.g. like the DefaultDataHandler), it is required to overwrite the + get_store_attributes method in addition to return attributes from the corresponding subclasses. This is not + required if only attributes from the main class are to be returned. + + Note that MLAir will store these attributes with the data handler's identification. This depends on the custom + data handler setting. When loading an attribute from the data handler, it is therefore required to extract the + right information by using the class identification. In case of the DefaultDataHandler this can be achieved by + converting all keys of the attribute to string and comparing these with the station parameter. 
+ """ + return list(set(cls._store_attributes)) + + def get_store_attributes(self): + """Returns all attribute names and values that are indicated by the store_attributes method.""" + attr_dict = {} + for attr in self.store_attributes(): + attr_dict[attr] = self.__getattribute__(attr) + return attr_dict + @classmethod def transformation(cls, *args, **kwargs): return None diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 11461ad77c3e910a897a9a1be48aef7cef45480a..73b6b53d3be657d5b14796e26c38ac1364d51112 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -33,6 +33,7 @@ class DefaultDataHandler(AbstractDataHandler): from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"]) + _store_attributes = data_handler.store_attributes() DEFAULT_ITER_DIM = "Stations" DEFAULT_TIME_DIM = "datetime" @@ -93,6 +94,16 @@ class DefaultDataHandler(AbstractDataHandler): logging.debug(f"save pickle data to {self._save_file}") self._reset_data() + def get_store_attributes(self): + attr_dict = {} + for attr in self.store_attributes(): + try: + val = self.__getattribute__(attr) + except AttributeError: + val = self.id_class.__getattribute__(attr) + attr_dict[attr] = val + return attr_dict + @staticmethod def _force_dask_computation(data): try: diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 9d44ce0b0e8d7b0bac9c188c697a5e65ab67df4c..d50f6f9ab7abcbe61259af8a880f24b01290661b 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -268,8 +268,22 @@ class PreProcessing(RunEnvironment): logging.info(f"run for {t_outer} to check {len(set_stations)} station(s). 
Found {len(collection)}/" f"{len(set_stations)} valid stations.") + if set_name == "train": + self.store_data_handler_attributes(data_handler, collection) return collection, valid_stations + def store_data_handler_attributes(self, data_handler, collection): + store_attributes = data_handler.store_attributes() + if len(store_attributes) > 0: + logging.info("store data requested by the data handler") + attrs = {} + for dh in collection: + station = str(dh) + for k, v in dh.get_store_attributes().items(): + attrs[k] = dict(attrs.get(k, {}), **{station: v}) + for k, v in attrs.items(): + self.data_store.set(k, v) + def validate_station_old(self, data_handler: AbstractDataHandler, set_stations, set_name=None, store_processed_data=True): """