From 113e891a8f25ff95739bde5c104ed333d0116dda Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Thu, 27 Feb 2020 11:26:51 +0100 Subject: [PATCH] get labels and forecast in correct behaviour (station-wise) --- src/data_handling/bootstraps.py | 31 ++++++-- src/run_modules/post_processing.py | 124 +++++++++++++++++++---------- 2 files changed, 104 insertions(+), 51 deletions(-) diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 60fc55fb..3c6c2c57 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -31,10 +31,10 @@ class BootStrapGenerator: """ return len(self.orig_generator)*self.boots*len(self.variables) - def get_labels(self): - for (_, label) in self.orig_generator: - for _ in range(self.boots): - yield label + def get_labels(self, key): + _, label = self.orig_generator[key] + for _ in range(self.boots): + yield label def get_generator(self): """ @@ -57,10 +57,16 @@ class BootStrapGenerator: shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables") boot_hist = boot_hist.combine_first(shuffled_var) boot_hist = boot_hist.sortby("variables") - self.bootstrap_meta.extend([var]*len_of_label) + self.bootstrap_meta.extend([[var, station]]*len_of_label) yield boot_hist, label return + def get_orig_prediction(self, path, file_name, prediction_name="CNN"): + file = os.path.join(path, file_name) + data = xr.open_dataarray(file) + for _ in range(self.boots): + yield data.sel(type=prediction_name).squeeze() + def load_boot_data(self, station): files = os.listdir(self.bootstrap_path) regex = re.compile(rf"{station}_\w*\.nc") @@ -90,16 +96,27 @@ class BootStraps(RunEnvironment): def get_boot_strap_generator_length(self): return self._boot_strap_generator.__len__() - def get_labels(self): + def get_labels(self, key): labels_list = [] chunks = None - for labels in self._boot_strap_generator.get_labels(): + for labels in self._boot_strap_generator.get_labels(key): if len(labels_list) == 0: chunks = (100, labels.data.shape[1]) labels_list.append(da.from_array(labels.data, chunks=chunks)) labels_out = da.concatenate(labels_list, axis=0) return labels_out.compute() + def get_orig_prediction(self, path, name): + labels_list = [] + chunks = None + for labels in self._boot_strap_generator.get_orig_prediction(path, name): + if len(labels_list) == 0: + chunks = (100, labels.data.shape[1]) + labels_list.append(da.from_array(labels.data, chunks=chunks)) + labels_out = da.concatenate(labels_list, axis=0) + labels_out = labels_out.compute() + return labels_out[~np.isnan(labels_out).any(axis=1), :] + def get_chunk_size(self): hist, _ = self.data[0] return (100, *hist.shape[1:], self.number_bootstraps) diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 97f06812..0a791617 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -54,6 +54,9 @@ class PostProcessing(RunEnvironment): self.create_boot_straps() def create_boot_straps(self): + + # forecast + bootstrap_path = self.data_store.get("bootstrap_path", "general") forecast_path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") @@ -62,20 +65,44 @@ class PostProcessing(RunEnvironment): bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(), steps=bootstraps.get_boot_strap_generator_length()) bootstrap_meta = np.array(bootstraps.get_boot_strap_meta()) - length = sum(bootstrap_meta == bootstrap_meta[0]) - variables = np.unique(bootstrap_meta) - for boot in variables: - ind = (bootstrap_meta == boot) - sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1)) - tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"]) - logging.info(tmp.shape) - file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc") - tmp.to_netcdf(file_name) - labels = bootstraps.get_labels().reshape((length, window_lead_time, 1)) - file_name = os.path.join(forecast_path, f"bootstraps_orig.nc") - orig = xr.DataArray(labels, coords=(range(length), range(window_lead_time), ["orig"]), dims=["index", "window", "boot"]) - logging.info(orig.shape) - orig.to_netcdf(file_name) + variables = np.unique(bootstrap_meta[:, 0]) + for station in np.unique(bootstrap_meta[:, 1]): + coords = None + for boot in variables: + ind = np.all(bootstrap_meta == [boot, station], axis=1) + length = sum(ind) + sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1)) + coords = (range(length), range(window_lead_time)) + tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "window", "type"]) + file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc") + tmp.to_netcdf(file_name) + labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1)) + file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc") + labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "window", "type"]) + labels.to_netcdf(file_name) + + # file_name = os.path.join(forecast_path, f"bootstraps_orig.nc") + # orig = xr.open_dataarray(file_name) + + # calc skill scores + skill_scores = statistics.SkillScores(None) + score = {} + for station in np.unique(bootstrap_meta[:, 1]): + file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc") + labels = xr.open_dataarray(file_name) + shape = labels.shape + orig = bootstraps.get_orig_prediction(forecast_path, f"forecasts_norm_{station}_test.nc").reshape(shape) + orig = xr.DataArray(orig, coords=(range(shape[0]), range(shape[1]), ["orig"]), dims=["index", "window", "type"]) + score[station] = {} + for boot in variables: + file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc") + boot_data = xr.open_dataarray(file_name) + boot_data = boot_data.combine_first(labels) + boot_data = boot_data.combine_first(orig) + score[station][boot] = skill_scores.general_skill_score(boot_data, forecast_name=boot, reference_name="orig") + + # plot + def _load_model(self): try: @@ -122,64 +149,72 @@ class PostProcessing(RunEnvironment): logging.debug("start make_prediction") for i, _ in enumerate(self.test_data): data = self.test_data.get_data_generator(i) - - nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3) input_data = data.get_transposed_history() # get scaling parameters mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) - # nn forecast - nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method) + for normalised in [True, False]: + # create empty arrays + nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3) + + # nn forecast + nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised) - # persistence - persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std, - transformation_method) + # persistence + persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std, + transformation_method, normalised) - # ols - ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method) + # ols + ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method, normalised) - # observation - observation = self._create_observation(data, None, mean, std, transformation_method) + # observation + observation = self._create_observation(data, None, mean, std, transformation_method, normalised) - # merge all predictions - full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency()) - all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']), - CNN=nn_prediction, - persi=persistence_prediction, - obs=observation, - OLS=ols_prediction) + # merge all predictions + full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency()) + all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']), + CNN=nn_prediction, + persi=persistence_prediction, + obs=observation, + OLS=ols_prediction) - # save all forecasts locally - path = self.data_store.get("forecast_path", "general") - file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc") - all_predictions.to_netcdf(file) + # save all forecasts locally + path = self.data_store.get("forecast_path", "general") + prefix = "forecasts_norm" if normalised else "forecasts" + file = os.path.join(path, f"{prefix}_{data.station[0]}_test.nc") + all_predictions.to_netcdf(file) def _get_frequency(self): getter = {"daily": "1D", "hourly": "1H"} return getter.get(self._sampling, None) @staticmethod - def _create_observation(data, _, mean, std, transformation_method): - return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method) + def _create_observation(data, _, mean, std, transformation_method, normalised): + obs = data.label.copy() + if not normalised: + obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method) + return obs - def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method): + def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method, normalised): tmp_ols = self.ols_model.predict(input_data) - tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method) + if not normalised: + tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method) tmp_ols = np.expand_dims(tmp_ols, axis=1) target_shape = ols_prediction.values.shape ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols return ols_prediction - def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method): + def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method, normalised): tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var}) - tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) + if not normalised: + tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) window_lead_time = self.data_store.get("window_lead_time", "general") persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)), axis=1) return persistence_prediction - def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method): + def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method, normalised): """ create the nn forecast for given input data. Inverse transformation is applied to the forecast to get the output in the original space. Furthermore, only the output of the main branch is returned (not all minor branches, if @@ -192,7 +227,8 @@ class PostProcessing(RunEnvironment): :return: """ tmp_nn = self.model.predict(input_data) - tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method) + if not normalised: + tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method) if tmp_nn.ndim == 3: nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0) elif tmp_nn.ndim == 2: -- GitLab