diff --git a/run.py b/run.py index 03eda04280a19e5c2bb9f1743c40f07e9e3fd2cc..71244fb9d15f594ac3ffbce60341d5c8dcb15f03 100644 --- a/run.py +++ b/run.py @@ -30,7 +30,8 @@ def main(parser_args): if __name__ == "__main__": formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' - logging.basicConfig(format=formatter, level=logging.INFO) + # logging.basicConfig(format=formatter, level=logging.INFO) + logging.basicConfig(format=formatter, level=logging.DEBUG) parser = argparse.ArgumentParser() parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 35d93dcbd932d1c298c0744fcd0205697576bb4c..5d4c805a53707612535dd8a92154440227d03bc4 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -12,10 +12,12 @@ import statsmodels.api as sm from src.run_modules.run_environment import RunEnvironment from src.data_handling.data_distributor import Distributor +from src.data_handling.data_generator import DataGenerator from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src import statistics from src import helpers from src.helpers import TimeTracking +from src.plotting.postprocessing_plotting import plot_monthly_summary class PostProcessing(RunEnvironment): @@ -25,14 +27,23 @@ class PostProcessing(RunEnvironment): self.model = self.data_store.get("best_model", "general") self.ols_model = None self.batch_size = self.data_store.get("batch_size", "general.model") - self.test_data = self.data_store.get("generator", "general.test") + self.test_data: DataGenerator = self.data_store.get("generator", "general.test") self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) - self.train_data = self.data_store.get("generator", "general.train") + self.train_data: DataGenerator = self.data_store.get("generator", "general.train") + self.plot_path = self.data_store.get("plot_path", "general") self._run() def _run(self): self.train_ols_model() preds_for_all_stations = self.make_prediction() + self.plot() + + def plot(self): + path = self.data_store.get("forecast_path", "general") + window_lead_time = self.data_store.get("window_lead_time", "general") + target_var = self.data_store.get("target_var", "general") + plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var, + plot_folder=self.plot_path) def calculate_test_score(self): test_score = self.model.evaluate(generator=self.test_data_distributed.distribute_on_batches(), @@ -50,6 +61,7 @@ class PostProcessing(RunEnvironment): self.ols_model = OrdinaryLeastSquaredModel(self.train_data) def make_prediction(self, freq="1D"): + logging.debug("start make_prediction") nn_prediction_all_stations = [] for i, v in enumerate(self.test_data): data = self.test_data.get_data_generator(i)