From 9937bd31b18113a8f5e5d7f94b9c51bc6a05d24f Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Thu, 6 Feb 2020 15:14:24 +0100
Subject: [PATCH] MLT can now handle hourly data for entire workflow

---
 src/model_modules/linear_model.py       | 2 +-
 src/plotting/postprocessing_plotting.py | 3 ++-
 src/run_modules/experiment_setup.py     | 5 +++--
 src/run_modules/post_processing.py      | 4 +++-
 src/run_modules/pre_processing.py       | 2 +-
 5 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/model_modules/linear_model.py b/src/model_modules/linear_model.py
index 17a9b232..d8c455b5 100644
--- a/src/model_modules/linear_model.py
+++ b/src/model_modules/linear_model.py
@@ -32,7 +32,7 @@ class OrdinaryLeastSquaredModel:
 
     def predict(self, data):
         data = sm.add_constant(self.reshape_xarray_to_numpy(data))
-        return self.model.predict(data)
+        return np.atleast_2d(self.model.predict(data))
 
     @staticmethod
     def reshape_xarray_to_numpy(data):
diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index cd49ddd5..ece1ad97 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -63,7 +63,8 @@ class PlotMonthlySummary(RunEnvironment):
             data = xr.open_dataarray(file_name)
 
             data_cnn = data.sel(type="CNN").squeeze()
-            data_cnn.coords["ahead"].values = [f"{days}d" for days in data_cnn.coords["ahead"].values]
+            if len(data_cnn.shape) > 1:
+                data_cnn.coords["ahead"].values = [f"{days}d" for days in data_cnn.coords["ahead"].values]
 
             data_orig = data.sel(type="orig", ahead=1).squeeze()
             data_orig.coords["ahead"] = "orig"
diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 9ecc421b..d2410de0 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -33,13 +33,13 @@ class ExperimentSetup(RunEnvironment):
                  window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
-                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None):
+                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily"):
 
         # create run framework
         super().__init__()
 
         # experiment setup
-        self._set_param("data_path", helpers.prepare_host())
+        self._set_param("data_path", helpers.prepare_host(sampling=sampling))
         self._set_param("trainable", trainable, default=False)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
 
@@ -72,6 +72,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("end", end, default="2017-12-31", scope="general")
         self._set_param("window_history_size", window_history_size, default=13)
         self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing")
+        self._set_param("sampling", sampling)
 
         # target
         self._set_param("target_var", target_var, default="o3")
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index e6f271ce..93cb27dc 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -140,7 +140,9 @@ class PostProcessing(RunEnvironment):
     def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method):
         tmp_ols = self.ols_model.predict(input_data)
         tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method)
-        ols_prediction.values = np.swapaxes(np.expand_dims(tmp_ols, axis=1), 2, 0)
+        tmp_ols = np.expand_dims(tmp_ols, axis=1)
+        target_shape = ols_prediction.values.shape
+        ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
         return ols_prediction
 
     def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method):
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 2a4632d5..c5b1c53f 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -13,7 +13,7 @@ from src.join import EmptyQueryResult
 
 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
 DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var",
                        "station_type", "overwrite_local_data", "start", "end"]
+                       "station_type", "overwrite_local_data", "start", "end", "sampling"]
 
 
 class PreProcessing(RunEnvironment):
-- 
GitLab
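
Usage sketch of the new keyword introduced by this patch, assuming the ExperimentSetup signature shown in the diff above; the station id is only a placeholder, and nothing here is part of the commit itself.

    # Hypothetical example: passing sampling="hourly" makes ExperimentSetup forward
    # the value to helpers.prepare_host() and store it as a run parameter, so the
    # preprocessing stage picks it up via DEFAULT_KWARGS_LIST. Omitting the keyword
    # keeps the previous behaviour (sampling="daily").
    from src.run_modules.experiment_setup import ExperimentSetup

    ExperimentSetup(stations=["DEBW107"], sampling="hourly")  # hourly workflow
    ExperimentSetup(stations=["DEBW107"])                     # unchanged daily workflow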