Skip to content
Snippets Groups Projects
Commit d31cfad6 authored by lukas leufen's avatar lukas leufen
Browse files

add time tracking for some potential bottlenecks, run.py can now run without...

add time tracking for some potential bottlenecks, run.py can now run without training AND model_setup, if a model was stored locally
parent e63014b3
No related branches found
No related tags found
2 merge requests!37include new development,!30Lukas issue037 feat run without training
Pipeline #29020 passed
...@@ -18,6 +18,7 @@ from src import statistics ...@@ -18,6 +18,7 @@ from src import statistics
from src.plotting.postprocessing_plotting import plot_conditional_quantiles from src.plotting.postprocessing_plotting import plot_conditional_quantiles
from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore
from src.datastore import NameNotFoundInDataStore from src.datastore import NameNotFoundInDataStore
from src.helpers import TimeTracking
class PostProcessing(RunEnvironment): class PostProcessing(RunEnvironment):
...@@ -26,7 +27,7 @@ class PostProcessing(RunEnvironment): ...@@ -26,7 +27,7 @@ class PostProcessing(RunEnvironment):
super().__init__() super().__init__()
self.model: keras.Model = self._load_model() self.model: keras.Model = self._load_model()
self.ols_model = None self.ols_model = None
self.batch_size: int = self.data_store.get("batch_size", "general.model") self.batch_size: int = self.data_store.get_default("batch_size", "general.model", 64)
self.test_data: DataGenerator = self.data_store.get("generator", "general.test") self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size)
self.train_data: DataGenerator = self.data_store.get("generator", "general.train") self.train_data: DataGenerator = self.data_store.get("generator", "general.train")
...@@ -36,8 +37,14 @@ class PostProcessing(RunEnvironment): ...@@ -36,8 +37,14 @@ class PostProcessing(RunEnvironment):
self._run() self._run()
def _run(self): def _run(self):
self.train_ols_model() with TimeTracking():
preds_for_all_stations = self.make_prediction() self.train_ols_model()
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
with TimeTracking():
preds_for_all_stations = self.make_prediction()
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
self.skill_scores = self.calculate_skill_scores() self.skill_scores = self.calculate_skill_scores()
self.plot() self.plot()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment