def _store_apriori(self):
    """
    Persist the apriori information from the data store inside the experiment directory.

    Reads the ``apriori`` entry from the data store. If it is set and non-empty, it is serialised
    with dill (pickle protocol 4) to ``<experiment_path>/data/apriori/apriori.pickle``. If the entry
    is missing, ``None`` or empty, nothing is written. Invoked right after the collected attributes
    have been written to the data store.
    """
    apriori = self.data_store.get_default("apriori", default=None)
    if not apriori:
        # Nothing to persist: entry is missing, None, or an empty container.
        return

    experiment_path = self.data_store.get("experiment_path")
    path = os.path.join(experiment_path, "data", "apriori")
    store_file = os.path.join(path, "apriori.pickle")
    if not os.path.exists(path):
        # presumably creates the (nested) target directory -- confirm against path_config
        path_config.check_path_and_create(path)
    with open(store_file, "wb") as f:
        dill.dump(apriori, f, protocol=4)
    logging.debug(f"Store apriori options locally for later use at: {store_file}")


def _load_apriori(self):
    """
    Load previously stored apriori information into the data store, if a file was configured.

    Only acts when the data store has no ``apriori`` entry yet (or it is ``None``) and an
    ``apriori_file`` entry is set. If that path points to a regular file, its content is unpickled
    with dill and stored under ``apriori``. Otherwise (path missing, or not a regular file such as a
    directory) a message is logged and the data store is left untouched, so that apriori is
    calculated from data. Invoked before station validation when no snapshot load path is given.
    """
    if self.data_store.get_default("apriori", default=None) is not None:
        # An apriori already present in the data store takes precedence over any file.
        return

    apriori_file = self.data_store.get_default("apriori_file", None)
    if apriori_file is None:
        return

    # isfile() instead of exists(): a directory path also "exists" but cannot be opened for
    # reading; treat it like a missing file and fall back instead of raising.
    if not os.path.isfile(apriori_file):
        logging.info(f"cannot load apriori file: {apriori_file}. Use fresh calculation from data.")
        return

    logging.info(f"use apriori data from given file: {apriori_file}")
    # NOTE(security): dill.load executes arbitrary code embedded in the file. Only point
    # ``apriori_file`` at files from a trusted source.
    with open(apriori_file, "rb") as pickle_file:
        self.data_store.set("apriori", dill.load(pickle_file))