diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index fa55d9d944eb03d6096eea7507045a1904360a1d..55a18a1a62b396fd8b4510416d8b056a5088e88a 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -26,8 +26,9 @@ DEFAULT_EPOCHS = 20
 DEFAULT_TARGET_VAR = "o3"
 DEFAULT_TARGET_DIM = "variables"
 DEFAULT_WINDOW_LEAD_TIME = 3
-DEFAULT_DIMENSIONS = {"new_index": ["datetime", "Stations"]}
 DEFAULT_TIME_DIM = "datetime"
+DEFAULT_ITER_DIM = "Stations"
+DEFAULT_DIMENSIONS = {"new_index": [DEFAULT_TIME_DIM, DEFAULT_ITER_DIM]}
 DEFAULT_INTERPOLATION_METHOD = "linear"
 DEFAULT_INTERPOLATION_LIMIT = 1
 DEFAULT_TRAIN_START = "1997-01-01"
diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py
index a4f71582ccb842ba45690fcf6db054be44f0bdbd..78638a13b4ea50cd073ca4599a291342fad849d4 100644
--- a/mlair/data_handler/data_handler_kz_filter.py
+++ b/mlair/data_handler/data_handler_kz_filter.py
@@ -23,11 +23,14 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
     _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
 
-    def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs):
+    DEFAULT_FILTER_DIM = "filter"
+
+    def __init__(self, *args, kz_filter_length, kz_filter_iter, filter_dim=DEFAULT_FILTER_DIM, **kwargs):
         self._check_sampling(**kwargs)
         # self.original_data = None  # ToDo: implement here something to store unfiltered data
         self.kz_filter_length = to_list(kz_filter_length)
         self.kz_filter_iter = to_list(kz_filter_iter)
+        self.filter_dim = filter_dim
         self.cutoff_period = None
         self.cutoff_period_days = None
         super().__init__(*args, **kwargs)
@@ -58,11 +61,11 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
     @TimeTrackingWrapper
     def apply_kz_filter(self):
         """Apply kolmogorov zurbenko filter only on inputs."""
-        kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim="datetime")
+        kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim)
         filtered_data: List[xr.DataArray] = kz.run()
         self.cutoff_period = kz.period_null()
         self.cutoff_period_days = kz.period_null_days()
-        self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name="filter"))
+        self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim))
 
     def create_filter_index(self) -> pd.Index:
         """
@@ -75,15 +78,15 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1)
         index = list(map(f, index.tolist()))
         index = list(map(lambda x: str(x) + "d", index)) + ["res"]
-        return pd.Index(index, name="filter")
+        return pd.Index(index, name=self.filter_dim)
 
     def get_transposed_history(self) -> xr.DataArray:
         """Return history.
 
-        :return: history with dimensions datetime, window, Stations, variables.
+        :return: history with dimensions datetime, window, Stations, variables, filter.
         """
-        return self.history.transpose("datetime", "window", "Stations", "variables", "filter").copy()
-
+        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim,
+                                      self.filter_dim).copy()
 
 class DataHandlerKzFilter(DefaultDataHandler):
     """Data handler using kz filtered data."""
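
Aside: the pattern above, filter components stacked along a configurable dimension name, can be reproduced in isolation. A minimal sketch with toy data, where a rolling mean stands in for MLAir's KZFilter and the names follow the patch defaults:

import numpy as np
import pandas as pd
import xarray as xr

time_dim, filter_dim = "datetime", "filter"
data = xr.DataArray(np.random.rand(100), dims=[time_dim],
                    coords={time_dim: pd.date_range("2020-01-01", periods=100)})

smooth = data.rolling({time_dim: 11}, center=True).mean()  # stand-in for one filter pass
residuum = data - smooth  # the "res" component named by create_filter_index
filtered = xr.concat([smooth, residuum], dim=pd.Index(["11d", "res"], name=filter_dim))
print(filtered.dims)  # ('filter', 'datetime')
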
""" - return self.history.transpose("datetime", "window", "Stations", "variables", "filter").copy() - + return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim, + self.filter_dim).copy() class DataHandlerKzFilter(DefaultDataHandler): """Data handler using kz filtered data.""" diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index f49cee73c0736384bc878751c3c26e968af91147..caaa7a62d1b772808dcaf58abdfa5483e80861e7 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -204,7 +204,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) start, end = window, 1 res = [] - window_array = self.create_index_array('window', range(start, end), squeeze_dim=self.target_dim) + window_array = self.create_index_array(self.window_dim.range(start, end), squeeze_dim=self.target_dim) for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): res_filter = [] data_filter = data.sel({"filter": filter_name}) diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index b176ccd6a3d9abf3d372c931b2182eaa3da95920..5c173eefa2577f535313c1b9180bfc132d1cc2e7 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -26,35 +26,36 @@ number = Union[float, int] num_or_list = Union[number, List[number]] data_or_none = Union[xr.DataArray, None] -# defaults -DEFAULT_STATION_TYPE = "background" -DEFAULT_NETWORK = "AIRBASE" -DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', - 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', - 'pblheight': 'maximum'} -DEFAULT_WINDOW_LEAD_TIME = 3 -DEFAULT_WINDOW_HISTORY_SIZE = 13 -DEFAULT_WINDOW_HISTORY_OFFSET = 0 -DEFAULT_TIME_DIM = "datetime" -DEFAULT_TARGET_VAR = "o3" -DEFAULT_TARGET_DIM = "variables" -DEFAULT_SAMPLING = "daily" -DEFAULT_INTERPOLATION_LIMIT = 0 -DEFAULT_INTERPOLATION_METHOD = "linear" - class DataHandlerSingleStation(AbstractDataHandler): + DEFAULT_STATION_TYPE = "background" + DEFAULT_NETWORK = "AIRBASE" + DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', + 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', + 'pblheight': 'maximum'} + DEFAULT_WINDOW_LEAD_TIME = 3 + DEFAULT_WINDOW_HISTORY_SIZE = 13 + DEFAULT_WINDOW_HISTORY_OFFSET = 0 + DEFAULT_TIME_DIM = "datetime" + DEFAULT_TARGET_VAR = "o3" + DEFAULT_TARGET_DIM = "variables" + DEFAULT_ITER_DIM = "Stations" + DEFAULT_WINDOW_DIM = "window" + DEFAULT_SAMPLING = "daily" + DEFAULT_INTERPOLATION_LIMIT = 0 + DEFAULT_INTERPOLATION_METHOD = "linear" def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE, network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, + iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET, window_lead_time=DEFAULT_WINDOW_LEAD_TIME, interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT, interpolation_method: Union[str, Tuple[str]] = 
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index b176ccd6a3d9abf3d372c931b2182eaa3da95920..5c173eefa2577f535313c1b9180bfc132d1cc2e7 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -26,35 +26,36 @@ number = Union[float, int]
 num_or_list = Union[number, List[number]]
 data_or_none = Union[xr.DataArray, None]
 
-# defaults
-DEFAULT_STATION_TYPE = "background"
-DEFAULT_NETWORK = "AIRBASE"
-DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
-                        'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
-                        'pblheight': 'maximum'}
-DEFAULT_WINDOW_LEAD_TIME = 3
-DEFAULT_WINDOW_HISTORY_SIZE = 13
-DEFAULT_WINDOW_HISTORY_OFFSET = 0
-DEFAULT_TIME_DIM = "datetime"
-DEFAULT_TARGET_VAR = "o3"
-DEFAULT_TARGET_DIM = "variables"
-DEFAULT_SAMPLING = "daily"
-DEFAULT_INTERPOLATION_LIMIT = 0
-DEFAULT_INTERPOLATION_METHOD = "linear"
-
 
 class DataHandlerSingleStation(AbstractDataHandler):
+    DEFAULT_STATION_TYPE = "background"
+    DEFAULT_NETWORK = "AIRBASE"
+    DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
+                            'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
+                            'pblheight': 'maximum'}
+    DEFAULT_WINDOW_LEAD_TIME = 3
+    DEFAULT_WINDOW_HISTORY_SIZE = 13
+    DEFAULT_WINDOW_HISTORY_OFFSET = 0
+    DEFAULT_TIME_DIM = "datetime"
+    DEFAULT_TARGET_VAR = "o3"
+    DEFAULT_TARGET_DIM = "variables"
+    DEFAULT_ITER_DIM = "Stations"
+    DEFAULT_WINDOW_DIM = "window"
+    DEFAULT_SAMPLING = "daily"
+    DEFAULT_INTERPOLATION_LIMIT = 0
+    DEFAULT_INTERPOLATION_METHOD = "linear"
 
     def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE,
                  network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING,
                  target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM,
+                 iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM,
                  window_history_size=DEFAULT_WINDOW_HISTORY_SIZE, window_history_offset=DEFAULT_WINDOW_HISTORY_OFFSET,
                  window_lead_time=DEFAULT_WINDOW_LEAD_TIME,
                  interpolation_limit: Union[int, Tuple[int]] = DEFAULT_INTERPOLATION_LIMIT,
                  interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                  overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
                  min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None, **kwargs):
-        super().__init__()  # path, station, statistics_per_var, transformation, **kwargs)
+        super().__init__()
         self.station = helpers.to_list(station)
         self.path = self.setup_data_path(data_path, sampling)
         self.statistics_per_var = statistics_per_var
@@ -69,6 +70,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.target_dim = target_dim
         self.target_var = target_var
         self.time_dim = time_dim
+        self.iter_dim = iter_dim
+        self.window_dim = window_dim
         self.window_history_size = window_history_size
         self.window_history_offset = window_history_offset
         self.window_lead_time = window_lead_time
@@ -118,16 +121,14 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         :return: history with dimensions datetime, window, Stations, variables.
         """
-        return self.history.transpose("datetime", "window", "Stations",
-                                      "variables").copy()  # ToDo: remove hardcoded dims
+        return self.history.transpose(self.time_dim, self.window_dim, self.iter_dim, self.target_dim).copy()
 
     def get_transposed_label(self) -> xr.DataArray:
         """Return label.
 
         :return: label with dimensions datetime*, window*, Stations, variables.
         """
-        return self.label.squeeze(["Stations", "variables"]).transpose("datetime",
-                                                                       "window").copy()  # ToDo: remove hardcoded dims
+        return self.label.squeeze([self.iter_dim, self.target_dim]).transpose(self.time_dim, self.window_dim).copy()
 
     def get_X(self, **kwargs):
         return self.get_transposed_history()
@@ -142,13 +143,14 @@ class DataHandlerSingleStation(AbstractDataHandler):
     def call_transform(self, inverse=False):
         opts_input = self._transformation[0]
         self.input_data, opts_input = self.transform(self.input_data, dim=self.time_dim, inverse=inverse,
-                                                     opts=opts_input)
+                                                     opts=opts_input, transformation_dim=self.target_dim)
         opts_target = self._transformation[1]
         self.target_data, opts_target = self.transform(self.target_data, dim=self.time_dim, inverse=inverse,
-                                                       opts=opts_target)
+                                                       opts=opts_target, transformation_dim=self.target_dim)
         self._transformation = (opts_input, opts_target)
 
-    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None):
+    def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None,
+                  transformation_dim=DEFAULT_TARGET_DIM):
         """
         Transform data according to given transformation settings.
 
@@ -161,17 +163,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         :param string/int dim: This param is not used for inverse transformation.
                 | for xarray.DataArray as string: name of dimension which should be standardised
                 | for pandas.DataFrame as int: axis of dimension which should be standardised
-        :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented
-            yet. This param is not used for inverse transformation.
         :param inverse: Switch between transformation and inverse transformation.
-        :param mean: Used for transformation (if required by 'method') based on external data. If 'None' the mean is
-            calculated over the data in this class instance.
-        :param std: Used for transformation (if required by 'method') based on external data. If 'None' the std is
-            calculated over the data in this class instance.
-        :param min: Used for transformation (if required by 'method') based on external data. If 'None' min_val is
-            extracted from the data in this class instance.
-        :param max: Used for transformation (if required by 'method') based on external data. If 'None' max_val is
-            extracted from the data in this class instance.
 
         :return: xarray.DataArrays or pandas.DataFrames:
                 #. mean: Mean of data
@@ -203,7 +195,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         if not inverse:
             transformed_values = []
             for var in data_in.variables.values:
-                data_var = data_in.sel(variables=[var])  # ToDo: replace hardcoded variables dim
+                data_var = data_in.sel(**{transformation_dim: [var]})
                 var_opts = opts.get(var, {})
                 _method = var_opts.get("method", "standardise")
                 _mean = var_opts.get("mean", None)
@@ -211,9 +203,9 @@ class DataHandlerSingleStation(AbstractDataHandler):
                 mean, std, values = locals()["f" if _mean is None else "f_apply"](data_var, _method, _mean, _std)
                 opts_updated[var] = {"method": _method, "mean": mean, "std": std}
                 transformed_values.append(values)
-            return xr.concat(transformed_values, dim="variables"), opts_updated  # ToDo: replace hardcoded variables dim
+            return xr.concat(transformed_values, dim=transformation_dim), opts_updated
         else:
-            return self.inverse_transform(data_in, opts)  # ToDo: add return statement
+            return self.inverse_transform(data_in, opts, transformation_dim)
 
     @TimeTrackingWrapper
     def setup_samples(self):
@@ -286,7 +278,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
     @staticmethod
     def download_data_from_join(file_name: str, meta_file: str, station, statistics_per_var, sampling,
-                                station_type=None, network=None, store_data_locally=True, data_origin: Dict = None) \
+                                station_type=None, network=None, store_data_locally=True, data_origin: Dict = None,
+                                time_dim=DEFAULT_TIME_DIM, target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM) \
             -> [xr.DataArray, pd.DataFrame]:
         """
         Download data from TOAR database using the JOIN interface.
@@ -304,8 +297,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
                                               network_name=network, sampling=sampling, data_origin=data_origin)
         df_all[station[0]] = df
         # convert df_all to xarray
-        xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
-        xarr = xr.Dataset(xarr).to_array(dim='Stations')
+        xarr = {k: xr.DataArray(v, dims=[time_dim, target_dim]) for k, v in df_all.items()}
+        xarr = xr.Dataset(xarr).to_array(dim=iter_dim)
         if store_data_locally is True:
             # save locally as nc/csv file
             xarr.to_netcdf(path=file_name)
@@ -313,7 +306,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
         return xarr, meta
 
     def download_data(self, *args, **kwargs):
-        data, meta = self.download_data_from_join(*args, **kwargs)
+        data, meta = self.download_data_from_join(*args, **kwargs, time_dim=self.time_dim, target_dim=self.target_dim,
+                                                  iter_dim=self.iter_dim)
         return data, meta
 
     @staticmethod
@@ -380,7 +374,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         _range = list(map(lambda x: x + offset, range(start, end)))
         for w in _range:
             res.append(data.shift({dim: -w}))
-        window_array = self.create_index_array('window', _range, squeeze_dim=self.target_dim)
+        window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim)
         res = xr.concat(res, dim=window_array)
         return res
 
@@ -593,7 +587,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         if len(msg) > 0:
             raise AttributeError(f"Inverse transform {method} can not be executed because following is None: {msg}")
 
-    def inverse_transform(self, data_in, opts) -> xr.DataArray:
+    def inverse_transform(self, data_in, opts, transformation_dim) -> xr.DataArray:
         """
         Perform inverse transformation.
 
@@ -614,16 +608,15 @@ class DataHandlerSingleStation(AbstractDataHandler):
             raise NotImplementedError
 
         transformed_values = []
-        sel_dim = "variables"  # ToDo: replace hardcoded variables dim
         squeeze = False
-        if sel_dim in data_in.coords:
-            if sel_dim not in data_in.dims:
-                data_in = data_in.expand_dims(sel_dim)
+        if transformation_dim in data_in.coords:
+            if transformation_dim not in data_in.dims:
+                data_in = data_in.expand_dims(transformation_dim)
                 squeeze = True
         else:
-            raise IndexError(f"Could not find given dimension: {sel_dim}. Available is: {data_in.coords}")
+            raise IndexError(f"Could not find given dimension: {transformation_dim}. Available is: {data_in.coords}")
         for var in data_in.variables.values:
-            data_var = data_in.sel(**{sel_dim: [var]})
+            data_var = data_in.sel(**{transformation_dim: [var]})
             var_opts = opts.get(var, {})
             _method = var_opts.get("method", None)
             if _method is None:
@@ -633,8 +626,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
             self.check_inverse_transform_params(_method, _mean, _std)
             values = f_inverse(data_var, _method, _mean, _std)
             transformed_values.append(values)
-        res = xr.concat(transformed_values, dim=sel_dim)
-        return res.squeeze(sel_dim) if squeeze else res
+        res = xr.concat(transformed_values, dim=transformation_dim)
+        return res.squeeze(transformation_dim) if squeeze else res
 
     def apply_transformation(self, data, base=None, dim=0, inverse=False):
         """
@@ -655,7 +648,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
             raise ValueError("apply transformation requires a reference for transformation options. Please specify if"
                              "you want to use input or target transformation using the parameter 'base'. Given was: "
                              + base)
-        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse)
+        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse,
+                              transformation_dim=self.target_dim)
 
 
 if __name__ == "__main__":
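
Aside: what the new transformation_dim argument buys in transform/inverse_transform, sketched standalone (hypothetical helper and toy data, not the class methods themselves):

import numpy as np
import pandas as pd
import xarray as xr

def standardise(data: xr.DataArray, dim: str, transformation_dim: str) -> xr.DataArray:
    """Standardise every variable separately along dim, then rebuild the array."""
    transformed = []
    for var in data.coords[transformation_dim].values:
        data_var = data.sel(**{transformation_dim: [var]})  # selection keyword built from the dim name
        transformed.append((data_var - data_var.mean(dim)) / data_var.std(dim))
    return xr.concat(transformed, dim=transformation_dim)

data = xr.DataArray(np.random.rand(10, 2), dims=["datetime", "variables"],
                    coords={"datetime": pd.date_range("2020-01-01", periods=10),
                            "variables": ["o3", "temp"]})
print(standardise(data, dim="datetime", transformation_dim="variables").std("datetime").values)  # ~[1. 1.]
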
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index af3e64f48d2c3c40cf536d848453659a277de80a..8914969ac683f01f3d5f2e833bb870b5c710f188 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -31,12 +31,16 @@ class DefaultDataHandler(AbstractDataHandler):
 
     _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"])
 
+    DEFAULT_ITER_DIM = "Stations"
+    DEFAULT_TIME_DIM = "datetime"
+
     def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0,
                  extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None,
-                 store_processed_data=True):
+                 store_processed_data=True, iter_dim=DEFAULT_ITER_DIM, time_dim=DEFAULT_TIME_DIM):
         super().__init__()
         self.id_class = id_class
-        self.interpolation_dim = "datetime"
+        self.time_dim = time_dim
+        self.iter_dim = iter_dim
         self.min_length = min_length
         self._X = None
         self._Y = None
@@ -46,7 +50,7 @@ class DefaultDataHandler(AbstractDataHandler):
         self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle")
         self._collection = self._create_collection()
         self.harmonise_X()
-        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
+        self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.time_dim)
         self._store(fresh_store=True, store_processed_data=store_processed_data)
 
     @classmethod
@@ -133,7 +137,7 @@ class DefaultDataHandler(AbstractDataHandler):
 
     def harmonise_X(self):
         X_original, Y_original = self.get_X_original(), self.get_Y_original()
-        dim = self.interpolation_dim
+        dim = self.time_dim
         intersect = reduce(np.intersect1d, map(lambda x: x.coords[dim].values, X_original))
         if len(intersect) < max(self.min_length, 1):
             X, Y = None, None
@@ -149,7 +153,7 @@ class DefaultDataHandler(AbstractDataHandler):
         return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse)
 
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
-                          timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
+                          timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM):
         """
         Multiply extremes.
 
@@ -281,14 +285,15 @@ class DefaultDataHandler(AbstractDataHandler):
             _inner()
 
         # aggregate all information
+        iter_dim = sp_keys.get("iter_dim", cls.DEFAULT_ITER_DIM)
         pop_list = []
         for i, transformation in enumerate(transformation_dict):
             for k in transformation.keys():
                 try:
                     if transformation[k]["mean"] is not None:
-                        transformation_dict[i][k]["mean"] = transformation[k]["mean"].mean("Stations")
+                        transformation_dict[i][k]["mean"] = transformation[k]["mean"].mean(iter_dim)
                     if transformation[k]["std"] is not None:
-                        transformation_dict[i][k]["std"] = transformation[k]["std"].mean("Stations")
+                        transformation_dict[i][k]["std"] = transformation[k]["std"].mean(iter_dim)
                 except KeyError:
                     pop_list.append((i, k))
         for (i, k) in pop_list:
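
Aside: the aggregation step above pools per-station statistics by averaging over the iteration dimension. A toy example (station labels are placeholders):

import xarray as xr

iter_dim = "Stations"
mean_per_station = xr.DataArray([42.0, 46.0], dims=[iter_dim],
                                coords={iter_dim: ["stationA", "stationB"]})
pooled = mean_per_station.mean(iter_dim)  # single value then applied to every station
print(float(pooled))  # 44.0
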
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index b1426ed289783bcd2fe7939d280aa20d6e703860..7fe6bc2fe82f50cb011f325707f102d51549c174 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -1155,11 +1155,12 @@ class PlotAvailability(AbstractPlotClass):
     """
 
     def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
-                 summary_name="data availability", time_dimension="datetime"):
+                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
         """Initialise."""
         # create standard Gantt plot for all stations (currently in single pdf file with single page)
         super().__init__(plot_folder, "data_availability")
-        self.dim = time_dimension
+        self.time_dim = time_dimension
+        self.window_dim = window_dimension
         self.sampling = self._get_sampling(sampling)
         self.linewidth = None
         if self.sampling == 'h':
@@ -1182,11 +1183,11 @@ class PlotAvailability(AbstractPlotClass):
         plt_dict = {}
         for subset, data_collection in generators.items():
             for station in data_collection:
-                labels = station.get_Y(as_numpy=False).resample({self.dim: self.sampling}, skipna=True).mean()
-                labels_bool = labels.sel(window=1).notnull()
-                group = (labels_bool != labels_bool.shift({self.dim: 1})).cumsum()
+                labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
+                labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
+                group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum()
                 plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values},
-                                         index=labels.coords[self.dim].values)
+                                         index=labels.coords[self.time_dim].values)
                 t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
                 t2 = [i[1:] for i in t if i[0]]
 
@@ -1201,8 +1202,8 @@ class PlotAvailability(AbstractPlotClass):
         for subset, data_collection in generators.items():
             all_data = None
             for station in data_collection:
-                labels = station.get_Y(as_numpy=False).resample({self.dim: self.sampling}, skipna=True).mean()
-                labels_bool = labels.sel(window=1).notnull()
+                labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
+                labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
                 if all_data is None:
                     all_data = labels_bool
                 else:
@@ -1210,9 +1211,9 @@ class PlotAvailability(AbstractPlotClass):
                     all_data = np.logical_or(tmp, labels_bool).combine_first(
                         all_data)  # apply logical on merge and fill missing with all_data
 
-            group = (all_data != all_data.shift({self.dim: 1})).cumsum()
+            group = (all_data != all_data.shift({self.time_dim: 1})).cumsum()
             plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values},
-                                     index=all_data.coords[self.dim].values)
+                                     index=all_data.coords[self.time_dim].values)
             t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
             t2 = [i[1:] for i in t if i[0]]
             if plt_dict.get(summary_name) is None:
@@ -1251,11 +1252,16 @@ 
 @TimeTrackingWrapper
 class PlotSeparationOfScales(AbstractPlotClass):
 
-    def __init__(self, collection: DataCollection, plot_folder: str = "."):
+    def __init__(self, collection: DataCollection, plot_folder: str = ".", time_dim="datetime", window_dim="window",
+                 filter_dim="filter", target_dim="variables"):
        """Initialise."""
         # create standard Gantt plot for all stations (currently in single pdf file with single page)
         plot_folder = os.path.join(plot_folder, "separation_of_scales")
         super().__init__(plot_folder, "separation_of_scales")
+        self.time_dim = time_dim
+        self.window_dim = window_dim
+        self.filter_dim = filter_dim
+        self.target_dim = target_dim
         self._plot(collection)
 
     def _plot(self, collection: DataCollection):
@@ -1265,7 +1271,7 @@ class PlotSeparationOfScales(AbstractPlotClass):
             station = dh.id_class.station[0]
             data = data.sel(Stations=station)
             # plt.subplots()
-            data.plot(x="datetime", y="window", col="filter", row="variables", robust=True)
+            data.plot(x=self.time_dim, y=self.window_dim, col=self.filter_dim, row=self.target_dim, robust=True)
             self.plot_name = f"{orig_plot_name}_{station}"
             self._save()
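
Aside: the shift/cumsum idiom PlotAvailability keeps using here, shown in isolation. Comparing a boolean series with its shifted self marks run boundaries, and the cumulative sum turns them into group ids that groupby can consume (toy series):

import xarray as xr

avail = xr.DataArray([True, True, False, True, True, True], dims=["datetime"])
group = (avail != avail.shift({"datetime": 1})).cumsum()  # first element always starts run 1
print(group.values)  # [1 1 2 3 3 3] -> three consecutive runs
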
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index a27607429148fe034d9d5f2f17c3d2caf6d3a22f..18ef98f81d29730d3d6fac9bb57e6dc91942df24 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -17,7 +17,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
-    DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN
+    DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.model_class import MyLittleModel as VanillaModel
@@ -213,6 +213,7 @@
                  window_lead_time: int = None,
                  dimensions=None,
                  time_dim=None,
+                 iter_dim=None,
                  interpolation_method=None,
                  interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None,
                  test_start=None,
@@ -298,6 +299,9 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("transformation", None, scope="preprocessing")
         self._set_param("data_handler", data_handler, default=DefaultDataHandler)
 
+        # iteration
+        self._set_param("iter_dim", iter_dim, default=DEFAULT_ITER_DIM)
+
         # target
         self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR)
         self._set_param("target_dim", target_dim, default=DEFAULT_TARGET_DIM)
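
Aside: with iter_dim registered in the data store, the dimension name can be chosen at the entry point. Hypothetical usage, assuming the usual wiring where DefaultWorkflow forwards its keyword arguments to ExperimentSetup (station id is the common demo value):

from mlair.workflows import DefaultWorkflow

# iter_dim/time_dim are the new ExperimentSetup parameters; defaults apply when omitted
workflow = DefaultWorkflow(stations=["DEBW107"], time_dim="datetime", iter_dim="Stations")
workflow.run()
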
scope="preprocessing") self._set_param("data_handler", data_handler, default=DefaultDataHandler) + # iteration + self._set_param("iter_dim", iter_dim, default=DEFAULT_ITER_DIM) + # target self._set_param("target_var", target_var, default=DEFAULT_TARGET_VAR) self._set_param("target_dim", target_dim, default=DEFAULT_TARGET_DIM) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index f1fd1d533c4c62012393a0115db17fbeb1bae017..d223858ccf056703b1e4c0975a382f11f6a150cc 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -293,12 +293,17 @@ class PostProcessing(RunEnvironment): path = self.data_store.get("forecast_path") plot_list = self.data_store.get("plot_list", "postprocessing") - time_dimension = self.data_store.get("time_dim") + time_dim = self.data_store.get("time_dim") + window_dim = self.data_store.get("window_dim") + target_dim = self.data_store.get("target_dim") + iter_dim = self.data_store.get("iter_dim") try: if ("filter" in self.test_data[0].get_X(as_numpy=False)[0].coords) and ( "PlotSeparationOfScales" in plot_list): - PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path) + filter_dim = self.data_store.get("filter_dim", None) + PlotSeparationOfScales(self.test_data, plot_folder=self.plot_path, time_dim=time_dim, + window_dim=window_dim, target_dim=target_dim, **{"filter_dim": filter_dim}) except Exception as e: logging.error(f"Could not create plot PlotSeparationOfScales due to the following error: {e}") @@ -365,14 +370,16 @@ class PostProcessing(RunEnvironment): try: if "PlotAvailability" in plot_list: avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} - PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dimension) + PlotAvailability(avail_data, plot_folder=self.plot_path, time_dimension=time_dim, + window_dimension=window_dim) except Exception as e: logging.error(f"Could not create plot PlotAvailability due to the following error: {e}") try: if "PlotAvailabilityHistogram" in plot_list: avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} - PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, ) # time_dimension=time_dimension) + PlotAvailabilityHistogram(avail_data, plot_folder=self.plot_path, station_dim=iter_dim, + history_dim=window_dim) except Exception as e: logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}") @@ -400,6 +407,7 @@ class PostProcessing(RunEnvironment): """ logging.debug("start make_prediction") time_dimension = self.data_store.get("time_dim") + window_dim = self.data_store.get("window_dim") for i, data in enumerate(self.test_data): input_data = data.get_X() target_data = data.get_Y(as_numpy=False) @@ -432,7 +440,7 @@ class PostProcessing(RunEnvironment): "persi": persistence_prediction, "obs": observation, "ols": ols_prediction} - all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']), + all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]), time_dimension, **prediction_dict) # save all forecasts locally @@ -441,57 +449,6 @@ class PostProcessing(RunEnvironment): file = os.path.join(path, f"{prefix}_{str(data)}_test.nc") all_predictions.to_netcdf(file) - def make_prediction_old(self): - """ - Create predictions for NN, OLS, and persistence and add true observation as reference. - - Predictions are filled in an array with full index range. 
-        predictions for a single station are stored locally under `<forecast/forecast_norm>_<station>_test.nc` and can
-        be found inside `forecast_path`.
-        """
-        logging.debug("start make_prediction")
-        time_dimension = self.data_store.get("time_dim")
-        for i, data in enumerate(self.test_data):
-            input_data = data.get_X()
-            target_data = data.get_Y(as_numpy=False)
-            observation_data = data.get_observation()
-
-            # get scaling parameters
-            transformation_opts = data.get_transformation_Y()
-
-            for normalised in [True, False]:
-                # create empty arrays
-                nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
-                    target_data, count=4)
-
-                # nn forecast
-                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_opts, normalised)
-
-                # persistence
-                persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
-                                                                           transformation_opts, normalised)
-
-                # ols
-                ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_opts, normalised)
-
-                # observation
-                observation = self._create_observation(target_data, observation, transformation_opts, normalised)
-
-                # merge all predictions
-                full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
-                prediction_dict = {self.forecast_indicator: nn_prediction,
-                                   "persi": persistence_prediction,
-                                   "obs": observation,
-                                   "ols": ols_prediction}
-                all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']),
-                                                              time_dimension, **prediction_dict)
-
-                # save all forecasts locally
-                path = self.data_store.get("forecast_path")
-                prefix = "forecasts_norm" if normalised else "forecasts"
-                file = os.path.join(path, f"{prefix}_{str(data)}_test.nc")
-                all_predictions.to_netcdf(file)
-
     def _get_frequency(self) -> str:
         """Get frequency abbreviation."""
         getter = {"daily": "1D", "hourly": "1H"}
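
Aside: taken together, the patch turns hardcoded dimension strings into parameters, so reorderings like get_transposed_history reduce to transpose-by-name. Recap on a toy array:

import numpy as np
import xarray as xr

time_dim, window_dim, iter_dim, target_dim = "datetime", "window", "Stations", "variables"
history = xr.DataArray(np.zeros((1, 14, 2, 100)), dims=[iter_dim, window_dim, target_dim, time_dim])
print(history.transpose(time_dim, window_dim, iter_dim, target_dim).dims)
# ('datetime', 'window', 'Stations', 'variables')
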