Skip to content
Snippets Groups Projects
Select Git revision
  • aa30af07037a3b17bd397948a5ac0c9dae926ce9
  • 2022 default
  • 2021
  • master protected
  • 2021
5 results

parallel_search-serial.f90

Blame
  • post_processing.py 12.04 KiB
    # module metadata: authors and creation date
    __author__, __date__ = "Lukas Leufen, Felix Kleinert", '2019-12-11'
    
    
    import logging
    import os
    
    import numpy as np
    import pandas as pd
    import xarray as xr
    import keras
    
    from src.run_modules.run_environment import RunEnvironment
    from src.data_handling.data_distributor import Distributor
    from src.data_handling.data_generator import DataGenerator
    from src.model_modules.linear_model import OrdinaryLeastSquaredModel
    from src import statistics
    from src.plotting.postprocessing_plotting import plot_conditional_quantiles
    from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, PlotCompetitiveSkillScore
    from src.datastore import NameNotFoundInDataStore
    
    
    class PostProcessing(RunEnvironment):
        """
        Post-processing run module: evaluate the trained model on the test data.

        Everything is executed on construction (see ``_run``):
        1. Train an ordinary least squares (OLS) reference model.
        2. Create NN, persistence, OLS and original (inverse transformed label) forecasts for each test station and
           store them as netCDF in ``forecast_path``.
        3. Calculate competitive and climatological skill scores from these files.
        4. Create all plots.
        """

        def __init__(self):
            super().__init__()
            self.model: keras.Model = self._load_model()
            self.ols_model = None  # set by train_ols_model()
            self.batch_size: int = self.data_store.get("batch_size", "general.model")
            self.test_data: DataGenerator = self.data_store.get("generator", "general.test")
            self.test_data_distributed = Distributor(self.test_data, self.model, self.batch_size)
            self.train_data: DataGenerator = self.data_store.get("generator", "general.train")
            self.train_val_data: DataGenerator = self.data_store.get("generator", "general.train_val")
            self.plot_path: str = self.data_store.get("plot_path", "general")
            self.skill_scores = None  # (competitive, climatological) tuple, set in _run()
            self._run()

        def _run(self):
            """Execute all post-processing steps: OLS training, prediction, skill scores and plotting."""
            self.train_ols_model()
            # Forecasts are written to disk here; calculate_skill_scores and plot read them back from the files.
            self.make_prediction()
            self.skill_scores = self.calculate_skill_scores()
            self.plot()

        def _load_model(self):
            """
            Get the model to evaluate.

            Uses ``best_model`` from the data store if available. Otherwise the model is loaded from
            ``<experiment_path>/<experiment_name>_my_model.h5``.

            :return: keras model
            """
            try:
                model = self.data_store.get("best_model", "general")
            except NameNotFoundInDataStore:
                logging.info("no model saved in data store. trying to load model from experiment")
                path = self.data_store.get("experiment_path", "general")
                name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5"
                model_name = os.path.join(path, name)
                model = keras.models.load_model(model_name)
            return model

        def plot(self):
            """
            Run all plotting routines.

            Relies on the forecast files written by ``make_prediction`` and on ``self.skill_scores``
            (index 0: competitive, index 1: climatological skill scores) set by ``_run``.
            """
            logging.debug("Run plotting routines...")
            path = self.data_store.get("forecast_path", "general")
            target_var = self.data_store.get("target_var", "general")

            plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="orig",
                                       forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path)
            plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
                                       forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
            PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
            PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
                               plot_folder=self.plot_path)
            PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
            PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
                                         extra_name_tag="all_terms_", model_setup="CNN")
            PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")

        def calculate_test_score(self):
            """Evaluate the model on one batch of the distributed test data, then log and save the scores."""
            test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
                                                       use_multiprocessing=False, verbose=0, steps=1)
            logging.info(f"test score = {test_score}")
            self._save_test_score(test_score)

        def _save_test_score(self, score):
            """
            Write the test scores to ``<experiment_path>/test_scores.txt``, one "name, value" pair per line.

            :param score: scalar or list of scalars as returned by keras ``evaluate_generator``. The matching labels
                are taken from ``self.model.metrics_names`` (the loss comes first).
            """
            path = self.data_store.get("experiment_path", "general")
            # keras returns a plain scalar if the model has a single output and no metrics
            scores = np.atleast_1d(score)
            # the file must be opened for writing; the default read mode makes f.write raise
            with open(os.path.join(path, "test_scores.txt"), "w") as f:
                for name, item in zip(self.model.metrics_names, scores):
                    f.write(f"{name}, {item}\n")

        def train_ols_model(self):
            """Fit the OLS reference model on the training data and store it in ``self.ols_model``."""
            self.ols_model = OrdinaryLeastSquaredModel(self.train_data)

        def make_prediction(self, freq="1D"):
            """
            Create NN, persistence, OLS and original forecasts for every station of the test data.

            For each station, all forecasts are merged into a single xarray (see ``create_forecast_arrays``). The merge
            uses a full date range of frequency ``freq``. The result is saved as ``forecasts_<station>_test.nc`` in
            the forecast path.

            :param freq: frequency of the full time index that all forecasts are aligned on (pandas offset alias)
            :return: list with the NN forecast of each station (in order of the test data)
            """
            logging.debug("start make_prediction")
            # loop-invariant: output directory of all forecast files
            path = self.data_store.get("forecast_path", "general")
            nn_prediction_all_stations = []
            for i, _ in enumerate(self.test_data):
                data = self.test_data.get_data_generator(i)

                nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
                input_data = self.test_data[i][0]

                # get scaling parameters
                mean, std, transformation_method = data.get_transformation_information(variable='o3')

                # nn forecast
                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method)

                # persistence
                persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std,
                                                                           transformation_method)

                # ols
                ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method)

                # orig pred
                orig_pred = self._create_orig_forecast(data, None, mean, std, transformation_method)

                # merge all predictions
                full_index = self.create_fullindex(data.data.indexes['datetime'], freq)
                all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']),
                                                              CNN=nn_prediction,
                                                              persi=persistence_prediction,
                                                              orig=orig_pred,
                                                              OLS=ols_prediction)

                # save all forecasts locally
                file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc")
                all_predictions.to_netcdf(file)

                # save nn forecast to return variable
                nn_prediction_all_stations.append(nn_prediction)
            return nn_prediction_all_stations

        @staticmethod
        def _create_orig_forecast(data, _, mean, std, transformation_method):
            """
            Return a copy of the label data of ``data``, back-transformed into the original space.

            The unused second parameter keeps the signature parallel to the other ``_create_*_forecast`` methods.
            """
            return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method)

        def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method):
            """
            Predict with the OLS model and write the back-transformed result into ``ols_prediction``.

            :param input_data: model input of one station
            :param ols_prediction: pre-shaped array (label copy) whose values are overwritten
            :param mean: scaling mean for the inverse transformation
            :param std: scaling standard deviation for the inverse transformation
            :param transformation_method: name of the transformation to invert
            :return: ``ols_prediction`` with the forecast values
            """
            tmp_ols = self.ols_model.predict(input_data)
            tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method)
            ols_prediction.values = np.swapaxes(np.expand_dims(tmp_ols, axis=1), 2, 0)
            return ols_prediction

        def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method):
            """
            Create the persistence forecast from the 'o3' input value at ``window=0``.

            That value is repeated for every one of the ``window_lead_time`` lead times. It is presumably the latest
            observation (TODO confirm the window convention).

            :return: ``persistence_prediction`` with the forecast values (back-transformed into the original space)
            """
            tmp_persi = input_data.sel({'window': 0, 'variables': 'o3'})
            tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
            window_lead_time = self.data_store.get("window_lead_time", "general")
            persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)),
                                                           axis=1)
            return persistence_prediction

        def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method):
            """
            Create the nn forecast for the given input data.

            The inverse transformation is applied so the output is in the original space. If the network has
            multiple output branches, only the main branch is returned. The main branch is defined as the last
            entry of all outputs.

            :param input_data: model input of one station
            :param nn_prediction: pre-shaped array (label copy) whose values are overwritten
            :param mean: scaling mean for the inverse transformation
            :param std: scaling standard deviation for the inverse transformation
            :param transformation_method: name of the transformation to invert
            :return: ``nn_prediction`` with the forecast values
            :raises NotImplementedError: if the model output has neither 2 nor 3 dimensions
            """
            tmp_nn = self.model.predict(input_data)
            tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
            if tmp_nn.ndim == 3:
                # several output branches: keep only the main (last) one
                nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
            elif tmp_nn.ndim == 2:
                nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
            else:
                # use .ndim: numpy arrays have no .dims attribute, which turned this error into an AttributeError
                raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.ndim}.")
            return nn_prediction

        @staticmethod
        def _create_empty_prediction_arrays(generator, count=1):
            """Return ``count`` independent copies of ``generator.label``; their values are overwritten later."""
            return [generator.label.copy() for _ in range(count)]

        @staticmethod
        def create_fullindex(df, freq):
            """
            Create an empty DataFrame whose index spans the first to last date of ``df`` at frequency ``freq``.

            :param df: pandas DataFrame, xarray DataArray (read via ``df.index``, i.e. an ``index`` coordinate is
                assumed) or pandas DatetimeIndex
            :param freq: frequency string passed to ``pd.date_range``
            :return: pandas DataFrame without columns, indexed by the full date range
            :raises AttributeError: if ``df`` has an unsupported type
            """
            if isinstance(df, pd.DataFrame):
                earliest = df.index[0]
                latest = df.index[-1]
            elif isinstance(df, xr.DataArray):
                earliest = df.index[0].values
                latest = df.index[-1].values
            elif isinstance(df, pd.DatetimeIndex):
                earliest = df[0]
                latest = df[-1]
            else:
                raise AttributeError(f"unknown array type. Only pandas dataframes, xarray dataarrays and pandas datetimes "
                                     f"are supported. Given type is {type(df)}.")
            index = pd.DataFrame(index=pd.date_range(earliest, latest, freq=freq))
            return index

        @staticmethod
        def create_forecast_arrays(index, ahead_names, **kwargs):
            """
            Combine different forecast types into one xarray.

            Time steps of a forecast that are not contained in ``index`` are dropped. Time steps of ``index`` without
            data stay NaN.

            :param index: as index; index for forecasts (e.g. time), read via ``index.index``
            :param ahead_names: as list of str/int: names of ahead values (e.g. hours or days)
            :param kwargs: as pandas or xarray objects; data of forecasts (the keyword is used as forecast type name).
                xarray data must have a ``datetime`` index and a ``Stations`` dimension of size 1.
            :return: xarray of dimension 3: index, ahead_names, # predictions
            """
            keys = list(kwargs.keys())
            res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
                               coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
            # the target index is the same for every forecast type, so build the lookup set once
            res_index = set(res.index.values)
            for k, v in kwargs.items():
                try:
                    v_index = v.index.values
                    is_xarray = False
                except AttributeError:  # v is xarray type and has no attribute .index
                    v_index = v.indexes['datetime'].values
                    is_xarray = True
                # Use a sorted array, not a set: np.stack rejects sets, and sorting keeps the order deterministic.
                match_index = np.array(sorted(res_index.intersection(v_index)))
                if match_index.size == 0:
                    continue  # no overlapping time steps: keep NaN for this forecast type
                if is_xarray:
                    res.loc[match_index, :, k] = v.sel({'datetime': match_index}).squeeze('Stations').transpose()
                else:
                    res.loc[match_index, :, k] = v.loc[match_index]
            return res

        def _get_external_data(self, station):
            """
            Get the original (inverse transformed) label data of a station from the train+val data.

            This serves as external reference for the climatological skill scores. Only ``window=1`` is kept, the
            ``window``, ``Stations`` and ``variables`` coordinates are dropped, and ``datetime`` is renamed to
            ``index``.

            :param station: station name
            :return: xarray of the reference data, or None if a KeyError occurs (e.g. station not in train+val data)
            """
            try:
                data = self.train_val_data.get_data_generator(station)
                mean, std, transformation_method = data.get_transformation_information(variable='o3')
                external_data = self._create_orig_forecast(data, None, mean, std, transformation_method)
                external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
                return external_data.rename({'datetime': 'index'})
            except KeyError:
                return None

        def calculate_skill_scores(self):
            """
            Calculate skill scores for each test station from the forecast files written by ``make_prediction``.

            :return: tuple (competitive, climatological). Each is a dict keyed by station name.
            """
            path = self.data_store.get("forecast_path", "general")
            window_lead_time = self.data_store.get("window_lead_time", "general")
            skill_score_competitive = {}
            skill_score_climatological = {}
            for station in self.test_data.stations:
                file = os.path.join(path, f"forecasts_{station}_test.nc")
                # Load into memory and close the file right away, instead of leaking one open handle per station.
                with xr.open_dataarray(file) as stored:
                    data = stored.load()
                skill_score = statistics.SkillScores(data)
                external_data = self._get_external_data(station)
                skill_score_competitive[station] = skill_score.skill_scores(window_lead_time)
                skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data,
                                                                                              window_lead_time)
            return skill_score_competitive, skill_score_climatological