From f0e21c8005b335a8b439b08c488fb31063993bc0 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 16 May 2023 11:36:55 +0200
Subject: [PATCH] use apriori_file to load external apriori information

---
Review notes (this section is ignored when the patch is applied):

What the patch does
  * In the pre-processing run, when snapshot_load_path is None,
    self._load_apriori() is now called before validate_station().
  * After the data handler attributes have been written to the data store,
    self._store_apriori() is now called.
  * _store_apriori(): reads "apriori" from the data store (default None).
    If the value is truthy, it is dumped with dill (protocol=4) to
    <experiment_path>/data/apriori/apriori.pickle; the directory is created
    through path_config.check_path_and_create() when it does not exist.
    Nothing is written for None or an empty value.
  * _load_apriori(): only acts when "apriori" is not yet in the data store
    and "apriori_file" is set. If that file exists, its content is loaded
    with dill and stored as "apriori"; otherwise an info message is logged
    and "apriori" stays unset.

NOTE(review): dill.load() can execute arbitrary code contained in the file.
  apriori_file must only ever point to a trusted file.
NOTE(review): os.path.exists() followed by open() is not atomic, and an
  existing but unreadable or corrupt file still raises from open()/dill.load()
  instead of taking the "Use fresh calculation" branch.
NOTE(review): the line breaks and indentation of this patch were reconstructed
  from a whitespace-collapsed copy, guided by the hunk line counts. The split
  of the validate_station() call, the indentation of the context lines, and
  the nesting level of self._store_apriori() and of the logging.debug() call
  are assumptions - verify against the original commit before applying.

 mlair/run_modules/pre_processing.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index 1fa2a3c5..bfb6dd4f 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -64,6 +64,7 @@ class PreProcessing(RunEnvironment):
         if snapshot_load_path is None:
             stations = self.data_store.get("stations")
             data_handler = self.data_store.get("data_handler")
+            self._load_apriori()
             _, valid_stations = self.validate_station(data_handler, stations,
                                                       "preprocessing")  # , store_processed_data=False)
             if len(valid_stations) == 0:
@@ -318,6 +319,30 @@ class PreProcessing(RunEnvironment):
                     attrs[k] = dict(attrs.get(k, {}), **{station: v})
             for k, v in attrs.items():
                 self.data_store.set(k, v)
+        self._store_apriori()
+
+    def _store_apriori(self):
+        apriori = self.data_store.get_default("apriori", default=None)
+        if apriori:
+            experiment_path = self.data_store.get("experiment_path")
+            path = os.path.join(experiment_path, "data", "apriori")
+            store_file = os.path.join(path, "apriori.pickle")
+            if not os.path.exists(path):
+                path_config.check_path_and_create(path)
+            with open(store_file, "wb") as f:
+                dill.dump(apriori, f, protocol=4)
+            logging.debug(f"Store apriori options locally for later use at: {store_file}")
+
+    def _load_apriori(self):
+        if self.data_store.get_default("apriori", default=None) is None:
+            apriori_file = self.data_store.get_default("apriori_file", None)
+            if apriori_file is not None:
+                if os.path.exists(apriori_file):
+                    logging.info(f"use apriori data from given file: {apriori_file}")
+                    with open(apriori_file, "rb") as pickle_file:
+                        self.data_store.set("apriori", dill.load(pickle_file))
+                else:
+                    logging.info(f"cannot load apriori file: {apriori_file}. Use fresh calculation from data.")
 
     def transformation(self, data_handler: AbstractDataHandler, stations):
         calculate_fresh_transformation = self.data_store.get_default("calculate_fresh_transformation", True)
-- 
GitLab