Skip to content
Snippets Groups Projects
Commit 4cd8fe3c authored by lukas leufen's avatar lukas leufen
Browse files

post processing uses plot_monthly_summary()

parent 7211a38a
No related branches found
No related tags found
2 merge requests!37include new development,!27Lukas issue032 feat plotting postprocessing
Pipeline #28023 passed
...@@ -30,7 +30,8 @@ def main(parser_args): ...@@ -30,7 +30,8 @@ def main(parser_args):
if __name__ == "__main__": if __name__ == "__main__":
formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
logging.basicConfig(format=formatter, level=logging.INFO) # logging.basicConfig(format=formatter, level=logging.INFO)
logging.basicConfig(format=formatter, level=logging.DEBUG)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
......
...@@ -12,10 +12,12 @@ import statsmodels.api as sm ...@@ -12,10 +12,12 @@ import statsmodels.api as sm
from src.run_modules.run_environment import RunEnvironment from src.run_modules.run_environment import RunEnvironment
from src.data_handling.data_distributor import Distributor from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.linear_model import OrdinaryLeastSquaredModel
from src import statistics from src import statistics
from src import helpers from src import helpers
from src.helpers import TimeTracking from src.helpers import TimeTracking
from src.plotting.postprocessing_plotting import plot_monthly_summary
class PostProcessing(RunEnvironment): class PostProcessing(RunEnvironment):
...@@ -25,14 +27,23 @@ class PostProcessing(RunEnvironment): ...@@ -25,14 +27,23 @@ class PostProcessing(RunEnvironment):
self.model = self.data_store.get("best_model", "general") self.model = self.data_store.get("best_model", "general")
self.ols_model = None self.ols_model = None
self.batch_size = self.data_store.get("batch_size", "general.model") self.batch_size = self.data_store.get("batch_size", "general.model")
self.test_data = self.data_store.get("generator", "general.test") self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size) self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size)
self.train_data = self.data_store.get("generator", "general.train") self.train_data: DataGenerator = self.data_store.get("generator", "general.train")
self.plot_path = self.data_store.get("plot_path", "general")
self._run() self._run()
def _run(self): def _run(self):
self.train_ols_model() self.train_ols_model()
preds_for_all_stations = self.make_prediction() preds_for_all_stations = self.make_prediction()
self.plot()
def plot(self):
path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
target_var = self.data_store.get("target_var", "general")
plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var,
plot_folder=self.plot_path)
def calculate_test_score(self): def calculate_test_score(self):
test_score = self.model.evaluate(generator=self.test_data_distributed.distribute_on_batches(), test_score = self.model.evaluate(generator=self.test_data_distributed.distribute_on_batches(),
...@@ -50,6 +61,7 @@ class PostProcessing(RunEnvironment): ...@@ -50,6 +61,7 @@ class PostProcessing(RunEnvironment):
self.ols_model = OrdinaryLeastSquaredModel(self.train_data) self.ols_model = OrdinaryLeastSquaredModel(self.train_data)
def make_prediction(self, freq="1D"): def make_prediction(self, freq="1D"):
logging.debug("start make_prediction")
nn_prediction_all_stations = [] nn_prediction_all_stations = []
for i, v in enumerate(self.test_data): for i, v in enumerate(self.test_data):
data = self.test_data.get_data_generator(i) data = self.test_data.get_data_generator(i)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment