From a8cc85f185cee28b5a27691a69faf9aa6d3ae545 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Tue, 28 Jul 2020 14:27:40 +0200 Subject: [PATCH] MLAir is now independent from the window_lead_time_parameter (is extracted from Y shape) --- src/data_handling/advanced_data_handling.py | 8 ++-- src/data_handling/bootstraps.py | 31 +-------------- src/helpers/__init__.py | 2 +- src/helpers/helpers.py | 7 ++++ src/plotting/postprocessing_plotting.py | 5 +-- src/run_modules/model_setup.py | 2 - src/run_modules/post_processing.py | 18 ++++----- src/run_modules/pre_processing.py | 44 --------------------- 8 files changed, 24 insertions(+), 93 deletions(-) diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py index 63d26bd9..ffd6b235 100644 --- a/src/data_handling/advanced_data_handling.py +++ b/src/data_handling/advanced_data_handling.py @@ -13,7 +13,7 @@ import datetime as dt import shutil import inspect -from typing import Union, List, Tuple +from typing import Union, List, Tuple, Dict import logging from functools import reduce from src.data_handling.station_preparation import StationPrep @@ -71,7 +71,7 @@ class AbstractDataPreparation: @classmethod def transformation(cls, *args, **kwargs): - raise NotImplementedError + return None def get_X(self, upsampling=False, as_numpy=False): raise NotImplementedError @@ -82,8 +82,8 @@ class AbstractDataPreparation: def get_data(self, upsampling=False, as_numpy=False): return self.get_X(upsampling, as_numpy), self.get_Y(upsampling, as_numpy) - def get_coordinates(self): - return None, None + def get_coordinates(self) -> Union[None, Dict]: + return None class DefaultDataPreparation(AbstractDataPreparation): diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 4aaa3cba..d026d551 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -17,7 +17,6 @@ import os from collections import Iterator, Iterable from itertools import chain -import dask.array as da import numpy as np import xarray as xr @@ -69,13 +68,12 @@ class BootstrapIterator(Iterator): return d.values @staticmethod - def shuffle(data: da.array) -> da.core.Array: + def shuffle(data: np.ndarray) -> np.ndarray: """ Shuffle randomly from given data (draw elements with replacement). :param data: data to shuffle - :param chunks: chunk size for dask - :return: shuffled data as dask core array (not computed yet) + :return: shuffled data as numpy array """ size = data.shape return np.random.choice(data.reshape(-1, ), size=size) @@ -131,28 +129,3 @@ class BootStraps(Iterable): prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze() vals = np.tile(prediction.data, (self.number_of_bootstraps, 1)) return vals[~np.isnan(vals).any(axis=1), :] - - - - -if __name__ == "__main__": - - from src.run_modules.experiment_setup import ExperimentSetup - from src.run_modules.run_environment import RunEnvironment - from src.run_modules.pre_processing import PreProcessing - - formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' - logging.basicConfig(format=formatter, level=logging.INFO) - - with RunEnvironment() as run_env: - ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013'], - station_type='background', trainable=True, window_history_size=9) - PreProcessing() - - data = run_env.data_store.get("generator", "general.test") - number_bootstraps = 10 - - boots = BootStraps(data, number_bootstraps) - for b in boots.boot_strap_generator(): - a, c = b - logging.info(f"len is {len(boots.get_boot_strap_meta())}") diff --git a/src/helpers/__init__.py b/src/helpers/__init__.py index 546713b3..9e2f612c 100644 --- a/src/helpers/__init__.py +++ b/src/helpers/__init__.py @@ -3,4 +3,4 @@ from .testing import PyTestRegex, PyTestAllEqual from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger -from .helpers import remove_items, float_round, dict_to_xarray, to_list +from .helpers import remove_items, float_round, dict_to_xarray, to_list, extract_value diff --git a/src/helpers/helpers.py b/src/helpers/helpers.py index 968ee538..b12d9028 100644 --- a/src/helpers/helpers.py +++ b/src/helpers/helpers.py @@ -92,3 +92,10 @@ def remove_items(obj: Union[List, Dict], items: Any): return remove_from_dict(obj, items) else: raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") + + +def extract_value(encapsulated_value): + try: + return extract_value(encapsulated_value[0]) + except TypeError: + return encapsulated_value diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 7e282022..2fe71e23 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -19,7 +19,6 @@ import xarray as xr from matplotlib.backends.backend_pdf import PdfPages from src import helpers -from src.data_handling import DataGenerator from src.data_handling.iterator import DataCollection from src.helpers import TimeTrackingWrapper @@ -881,7 +880,7 @@ class PlotAvailability(AbstractPlotClass): """ - def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily", + def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily", summary_name="data availability", time_dimension="datetime"): """Initialise.""" # create standard Gantt plot for all stations (currently in single pdf file with single page) @@ -927,7 +926,7 @@ class PlotAvailability(AbstractPlotClass): plt_dict[str(station)].update({subset: t2}) return plt_dict - def _summarise_data(self, generators: Dict[str, DataGenerator], summary_name: str): + def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str): plt_dict = {} for subset, data_collection in generators.items(): all_data = None diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 7de1c7b6..ea6199c9 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -31,8 +31,6 @@ class ModelSetup(RunEnvironment): * `trainable` [.] * `create_new_model` [.] * `generator` [train] - * `window_lead_time` [.] - * `window_history_size` [.] * `model_class` [.] Optional objects diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index f63e92ba..4b32ecb7 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -15,7 +15,7 @@ import xarray as xr from src.data_handling import BootStraps, KerasIterator from src.helpers.datastore import NameNotFoundInDataStore -from src.helpers import TimeTracking, statistics +from src.helpers import TimeTracking, statistics, extract_value from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.model_class import AbstractModelClass from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ @@ -42,7 +42,7 @@ class PostProcessing(RunEnvironment): * `model_path` [.] * `target_var` [.] * `sampling` [.] - * `window_lead_time` [.] + * `output_shape` [model] * `evaluate_bootstraps` [postprocessing] and if enabled: * `create_new_bootstraps` [postprocessing] @@ -74,6 +74,7 @@ class PostProcessing(RunEnvironment): self.plot_path: str = self.data_store.get("plot_path") self.target_var = self.data_store.get("target_var") self._sampling = self.data_store.get("sampling") + self.window_lead_time = extract_value(self.data_store.get("output_shape", "model")) self.skill_scores = None self.bootstrap_skill_scores = None self._run() @@ -182,7 +183,6 @@ class PostProcessing(RunEnvironment): # extract all requirements from data store bootstrap_path = self.data_store.get("bootstrap_path") forecast_path = self.data_store.get("forecast_path") - window_lead_time = self.data_store.get("window_lead_time") number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing") forecast_file = f"forecasts_norm_%s_test.nc" bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps() @@ -203,14 +203,14 @@ class PostProcessing(RunEnvironment): orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"]) # calculate skill scores for each variable - skill = pd.DataFrame(columns=range(1, window_lead_time + 1)) + skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1)) for boot_set in bootstraps: boot_var = f"{boot_set[0]}_{boot_set[1]}" file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc") boot_data = xr.open_dataarray(file_name) boot_data = boot_data.combine_first(labels).combine_first(orig) boot_scores = [] - for ahead in range(1, window_lead_time + 1): + for ahead in range(1, self.window_lead_time + 1): data = boot_data.sel(ahead=ahead) boot_scores.append( skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig")) @@ -429,8 +429,7 @@ class PostProcessing(RunEnvironment): tmp_persi = data.copy() if not normalised: tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) - window_lead_time = self.data_store.get("window_lead_time") - persistence_prediction.values = np.tile(tmp_persi, (window_lead_time, 1)).T + persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T return persistence_prediction def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray, @@ -547,7 +546,6 @@ class PostProcessing(RunEnvironment): :return: competitive and climatological skill scores """ path = self.data_store.get("forecast_path") - window_lead_time = self.data_store.get("window_lead_time") skill_score_competitive = {} skill_score_climatological = {} for station in self.test_data: @@ -555,7 +553,7 @@ class PostProcessing(RunEnvironment): data = xr.open_dataarray(file) skill_score = statistics.SkillScores(data) external_data = self._get_external_data(station) - skill_score_competitive[station] = skill_score.skill_scores(window_lead_time) + skill_score_competitive[station] = skill_score.skill_scores(self.window_lead_time) skill_score_climatological[station] = skill_score.climatological_skill_scores(external_data, - window_lead_time) + self.window_lead_time) return skill_score_competitive, skill_score_climatological diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 1b7124d0..2e78887f 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -10,7 +10,6 @@ from typing import Tuple, Dict, List import numpy as np import pandas as pd -from src.data_handling import DataGenerator from src.data_handling import DataCollection from src.helpers import TimeTracking from src.configuration import path_config @@ -196,49 +195,6 @@ class PreProcessing(RunEnvironment): self.data_store.set("stations", valid_stations, scope=set_name) self.data_store.set("data_collection", collection, scope=set_name) - @staticmethod - def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True, - name=None): - """ - Check if all given stations in `all_stations` are valid. - - Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the - loading time are logged in debug mode. - - :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`, - `variables`, `interpolate_dim`, `target_dim`, `target_var`). - :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`, - `window_lead_time`). - :param all_stations: All stations to check. - :param name: name to display in the logging info message - - :return: Corrected list containing only valid station IDs. - """ - t_outer = TimeTracking() - t_inner = TimeTracking(start=False) - logging.info(f"check valid stations started{' (%s)' % name if name else ''}") - valid_stations = [] - - # all required arguments of the DataGenerator can be found in args, positional arguments in args and kwargs - data_gen = DataGenerator(**args, **kwargs) - for pos, station in enumerate(all_stations): - t_inner.run() - logging.info(f"check station {station} ({pos + 1} / {len(all_stations)})") - try: - data = data_gen.get_data_generator(key=station, load_local_tmp_storage=load_tmp, - save_local_tmp_storage=save_tmp) - if data.history is None: - raise AttributeError - valid_stations.append(station) - logging.debug( - f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') - logging.debug(f"{station}: loading time = {t_inner}") - except (AttributeError, EmptyQueryResult): - continue - logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/" - f"{len(all_stations)} valid stations.") - return valid_stations - def validate_station(self, data_preparation, set_stations, set_name=None, overwrite_local_data=False): """ Check if all given stations in `all_stations` are valid. -- GitLab