diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index a9695064e1d2864d4a367a297ba94cc404d46538..e6f271ce3cc6cf2548ff5b06ba40e2fd509f8c8d 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -18,6 +18,7 @@ from src import statistics from src.plotting.postprocessing_plotting import plot_conditional_quantiles from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore from src.datastore import NameNotFoundInDataStore +from src.helpers import TimeTracking class PostProcessing(RunEnvironment): @@ -26,7 +27,7 @@ class PostProcessing(RunEnvironment): super().__init__() self.model: keras.Model = self._load_model() self.ols_model = None - self.batch_size: int = self.data_store.get("batch_size", "general.model") + self.batch_size: int = self.data_store.get_default("batch_size", "general.model", 64) self.test_data: DataGenerator = self.data_store.get("generator", "general.test") self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) self.train_data: DataGenerator = self.data_store.get("generator", "general.train") @@ -36,8 +37,14 @@ class PostProcessing(RunEnvironment): self._run() def _run(self): - self.train_ols_model() - preds_for_all_stations = self.make_prediction() + with TimeTracking(): + self.train_ols_model() + logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " + "skip make_prediction() whenever it is possible to save time.") + with TimeTracking(): + preds_for_all_stations = self.make_prediction() + logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " + "skip make_prediction() whenever it is possible to save time.") self.skill_scores = self.calculate_skill_scores() self.plot()