diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index ec0f1f73282979a1e69945e1ad7f6817bdf3ba12..c1634e50e5515f08af5f9e49ecb037c836811c7b 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -23,6 +23,7 @@ from mlair import helpers from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict from mlair.data_handler.abstract_data_handler import AbstractDataHandler from mlair.helpers import data_sources +from mlair.helpers.data_sources.toar_data import correct_stat_name # define a more general date type for type hinting date = Union[dt.date, dt.datetime] @@ -83,11 +84,10 @@ class DataHandlerSingleStation(AbstractDataHandler): self.data_origin = data_origin self.do_transformation = transformation is not None self.input_data, self.target_data = None, None - self._transformation = self.setup_transformation(transformation) self.sampling = sampling self.target_dim = target_dim - self.target_var = target_var + self._target_var = target_var self.time_dim = time_dim self.iter_dim = iter_dim self.window_dim = window_dim @@ -109,11 +109,23 @@ class DataHandlerSingleStation(AbstractDataHandler): # internal self._data: xr.DataArray = None # loaded raw data self.meta = None - self.variables = sorted(list(statistics_per_var.keys())) if variables is None else variables + # self.variables = sorted(list(statistics_per_var.keys())) if variables is None else variables + if variables is None: + self.variables = sorted( + data_sources.get_vars_with_stat_name(statistics_per_var, variables=None) + ) + + else: + self.variables = sorted( + data_sources.get_vars_with_stat_name(statistics_per_var, variables=variables) + ) + self.history = None self.label = None self.observation = None + self._transformation = self.setup_transformation(transformation) + # create samples self.setup_samples() self.clean_up() @@ -308,14 +320,25 
@@ class DataHandlerSingleStation(AbstractDataHandler): def set_inputs_and_targets(self): inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)}) targets = self._data.sel( - {self.target_dim: helpers.to_list(self.target_var)}) # ToDo: is it right to expand this dim?? + {self.target_dim: helpers.to_list(self.target_var) + }) # ToDo: is it right to expand this dim?? self.input_data = inputs self.target_data = targets + @property + def target_var(self): + """ + Combine the target variable and its corresponding first statistic to create the combined target variable name. + :return: combined target variable name, e.g. "o3_dma8eu" + :rtype: str + """ + t_var = helpers.data_sources.get_target_var_with_stat_name(self.statistics_per_var, self._target_var) + return t_var + def make_samples(self): self.make_history_window(self.target_dim, self.window_history_size, self.time_dim) #todo stopped here - self.make_labels(self.target_dim, self.target_var, self.time_dim, self.window_lead_time) - self.make_observation(self.target_dim, self.target_var, self.time_dim) + self.make_labels(self.time_dim, self.window_lead_time) + self.make_observation(self.time_dim) self.remove_nan(self.time_dim) def load_data(self, path, station, statistics_per_var, sampling, store_data_locally=False, @@ -449,7 +472,11 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: corrected data """ - used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values)) + # used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values)) + # add "_" at end of all chemical variables to ensure differentiation between e.g.
"no" and "no2" + chemical_starter = [f"{v}_" for v in self.chem_vars] + #check if variables from target_dim start with chemical variable type + used_chem_vars = [v for v in data.coords[self.target_dim].values if v.startswith(tuple(chemical_starter))] if len(used_chem_vars) > 0: data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) return data @@ -502,13 +529,13 @@ class DataHandlerSingleStation(AbstractDataHandler): @staticmethod def _set_file_name(path, station, statistics_per_var): - all_vars = sorted(statistics_per_var.keys()) - return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc") + all_vars = sorted(data_sources.get_vars_with_stat_name(statistics_per_var)) + return os.path.join(path, f"{''.join(station)}__{'__'.join(all_vars)}.nc") @staticmethod def _set_meta_file_name(path, station, statistics_per_var): - all_vars = sorted(statistics_per_var.keys()) - return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv") + all_vars = sorted(data_sources.get_vars_with_stat_name(statistics_per_var)) + return os.path.join(path, f"{''.join(station)}__{'__'.join(all_vars)}_meta.csv") def interpolate(self, data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, sampling="daily", **kwargs): @@ -583,8 +610,7 @@ class DataHandlerSingleStation(AbstractDataHandler): offset = self.window_history_offset + self.window_history_end self.history = self.shift(data, dim_name_of_shift, window, offset=offset) - def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, - window: int) -> None: + def make_labels(self, dim_name_of_shift: str, window: int) -> None: """ Create a xr.DataArray containing labels. 
@@ -600,7 +626,7 @@ class DataHandlerSingleStation(AbstractDataHandler): data = self.target_data self.label = self.shift(data, dim_name_of_shift, window) - def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None: + def make_observation(self, dim_name_of_shift: str) -> None: """ Create a xr.DataArray containing observations. @@ -683,12 +709,16 @@ class DataHandlerSingleStation(AbstractDataHandler): if transformation is None: return None, None elif isinstance(transformation, dict): - return copy.deepcopy(transformation), copy.deepcopy(transformation) + return copy.deepcopy(self._update_transformation_keys(transformation)), copy.deepcopy(self._update_transformation_keys(transformation)) elif isinstance(transformation, tuple) and len(transformation) == 2: return copy.deepcopy(transformation) else: raise NotImplementedError("Cannot handle this.") + def _update_transformation_keys(self, transformation) -> dict: + return {f"{k}_{correct_stat_name(st)}": v for k, v in transformation.items() for st in helpers.to_list(self.statistics_per_var[k])} + + @staticmethod def check_inverse_transform_params(method: str, mean=None, std=None, min=None, max=None) -> None: """ diff --git a/mlair/helpers/data_sources/__init__.py b/mlair/helpers/data_sources/__init__.py index 6b753bc3afb961be65ff0f934ef4f0de08804a0b..2ae2d78a732bf56dbd0d8e91a6b48d10694aff2c 100644 --- a/mlair/helpers/data_sources/__init__.py +++ b/mlair/helpers/data_sources/__init__.py @@ -8,3 +8,5 @@ __author__ = "Lukas Leufen" __date__ = "2022-07-05" from . 
import era5, join, toar_data, toar_data_v2 + +from .toar_data import get_vars_with_stat_name, get_target_var_with_stat_name diff --git a/mlair/helpers/data_sources/join.py b/mlair/helpers/data_sources/join.py index a978b2712a83b21f3c1256b2bf0826da63bdda3a..279585bf40bb9f9ec87cef351f3fd4591cfb9d12 100644 --- a/mlair/helpers/data_sources/join.py +++ b/mlair/helpers/data_sources/join.py @@ -66,8 +66,10 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t logging.debug('load: {}'.format(var)) + stats_to_request = convert_to_comma_separated_string(stat_var[var]) + # create data link - opts = {'base': join_url_base, 'service': 'stats', 'id': vars_dict[var], 'statistics': stat_var[var], + opts = {'base': join_url_base, 'service': 'stats', 'id': vars_dict[var], 'statistics': stats_to_request, 'sampling': sampling, 'capture': 0, 'format': 'json'} # load data @@ -80,7 +82,8 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t data = correct_data_format(data) # correct namespace of statistics - stat = toar_data.correct_stat_name(stat_var[var]) + # stat = toar_data.correct_stat_name(stat_var[var]) + stat = [toar_data.correct_stat_name(s) for s in helpers.to_list(stat_var[var])] # store data in pandas dataframe df = _save_to_pandas(df, data, stat, var) @@ -103,6 +106,20 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t raise toar_data.EmptyQueryResult("No data found in JOIN.") +def convert_to_comma_separated_string(inputpattern: list): + """ + Convert a list to a single comma-separated string.
+ + + :param inputpattern: + :type inputpattern: + :return: + :rtype: + """ + return ','.join(helpers.to_list(inputpattern)) + + + def _correct_meta(meta): meta_out = {} for k, v in meta.items(): @@ -336,11 +353,13 @@ def _save_to_pandas(df: Union[pd.DataFrame, None], data: dict, stat: str, var: s str_format = "%Y-%m-%d %H:%M:%S" else: str_format = "%Y-%m-%d %H:%M" - index = map(lambda s: dt.datetime.strptime(s, str_format), data['datetime']) - if df is None: - df = pd.DataFrame(data[stat], index=index, columns=[var]) - else: - df = pd.concat([df, pd.DataFrame(data[stat], index=index, columns=[var])], axis=1) + index = list(map(lambda s: dt.datetime.strptime(s, str_format), data['datetime'])) + for st in helpers.to_list(stat): + col_name = f"{var}_{st}" + if df is None: + df = pd.DataFrame(data[st], index=index, columns=[col_name]) + else: + df = pd.concat([df, pd.DataFrame(data[st], index=index, columns=[col_name])], axis=1) return df diff --git a/mlair/helpers/data_sources/toar_data.py b/mlair/helpers/data_sources/toar_data.py index 27522855cbe0f3c6f0b78d1598709a694fc7b862..af41aac2663be716c30b6a88f83b581345c73fa1 100644 --- a/mlair/helpers/data_sources/toar_data.py +++ b/mlair/helpers/data_sources/toar_data.py @@ -10,6 +10,7 @@ import requests from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry import pandas as pd +from mlair.helpers import to_list class EmptyQueryResult(Exception): @@ -124,5 +125,42 @@ def correct_stat_name(stat: str) -> str: :return: stat mapped to local namespace """ - mapping = {'average_values': 'mean', 'maximum': 'max', 'minimum': 'min'} + mapping = {'average_values': 'mean', 'maximum': 'max', 'minimum': 'min', + 'perc05': 'p05', 'perc10': 'p10', 'perc25': 'p25', 'perc75': 'p75', + 'perc90': 'p90', 'perc95': 'p95', 'perc98': 'p98', + } return mapping.get(stat, stat) + + +def get_vars_with_stat_name(data: dict, variables: Union[list, None] = None) -> list: + """ + Returns a list of variable names 
consisting of var_name and applied stats: e.g. o3_dma8eu + :param data: Dictionary containing variable names as keys and aggregation statistic(s) as values + :type data: dict + :param variables: optional variables if only a subset from data.keys() is used + :type variables: list + :return: list of combined variable names + :rtype: list[str] + """ + if variables is None: + selected_vars = data.keys() + else: + selected_vars = variables + return [f"{v}_{correct_stat_name(st)}" for v in selected_vars for st in to_list(data[v])] + + +def get_target_var_with_stat_name(data: dict, variable: str, idx_target_stat=0) -> str: + """ + Return the combined name of the target variable (variable name plus one of its statistics). + :param data: dictionary containing variable names as keys and aggregation statistic(s) as values + :type data: dict + :param variable: target variable name + :type variable: str + :param idx_target_stat: index of the statistic to use for the target (default: first statistic) + :type idx_target_stat: int + :return: target variable name with statistic suffix + :rtype: str + """ + return get_vars_with_stat_name(data, to_list(variable))[idx_target_stat] + + diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index c7647ef5bf5b5b6c46eae9318c0fd99b294292c6..7540ac50f347b1c037f07fdf18bacd16859fa0bc 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -119,9 +119,15 @@ class PlotMonthlySummary(AbstractPlotClass): # pragma: no cover return min(ahead_steps, window_lead_time) @staticmethod - def _spell_out_chemical_concentrations(short_name: str): + def _spell_out_chemical_concentrations(short_name: str, add_stat_name=False): + short_var_name, stat_name = short_name.split('_') short2long = {'o3': 'ozone', 'no': 'nitrogen oxide', 'no2': 'nitrogen dioxide', 'nox': 'nitrogen dioxides'} - return f"{short2long[short_name]} concentration" + if add_stat_name: + long_name = f"{stat_name} {short2long[short_var_name]} concentration" + else: + long_name = f"{short2long[short_var_name]} concentration" + return long_name + # return f"{short2long[short_name]} concentration" def _plot(self, target_var: str, target_var_unit: str): """ diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index
a48a82b25804da34d07807073b0c153408e4e028..86eadfe6dd26f0634e858ee42dee29542e1a2504 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -19,7 +19,8 @@ import datetime as dt from mlair.configuration import path_config from mlair.data_handler import Bootstraps, KerasIterator from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope -from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables +from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables, \ + data_sources from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ @@ -83,6 +84,8 @@ class PostProcessing(RunEnvironment): self.forecast_path = self.data_store.get("forecast_path") self.plot_path: str = self.data_store.get("plot_path") self.target_var = self.data_store.get("target_var") + self.target_var_with_stat = data_sources.get_target_var_with_stat_name( + self.data_store.get("statistics_per_var"), self.target_var) self._sampling = self.data_store.get("sampling") self.window_lead_time = extract_value(self.data_store.get("output_shape", "model")) self.skill_scores = None @@ -552,7 +555,7 @@ class PostProcessing(RunEnvironment): PlotFeatureImportanceSkillScore( boot_skill_score, plot_folder=self.plot_path, model_name=self.model_display_name, sampling=self._sampling, ahead_dim=self.ahead_dim, - separate_vars=to_list(self.target_var), bootstrap_type=boot_type, + separate_vars=to_list(self.target_var_with_stat), bootstrap_type=boot_type, bootstrap_method=boot_method, branch_names=branch_names) except Exception as e: logging.error(f"Could not create plot PlotFeatureImportanceSkillScore ({boot_type}, "