diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index c1634e50e5515f08af5f9e49ecb037c836811c7b..b4f7073337f36058c209403fa563ec976b6eec79 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -110,15 +110,13 @@ class DataHandlerSingleStation(AbstractDataHandler): self._data: xr.DataArray = None # loaded raw data self.meta = None # self.variables = sorted(list(statistics_per_var.keys())) if variables is None else variables - if variables is None: - self.variables = sorted( - data_sources.get_vars_with_stat_name(statistics_per_var, variables=None) - ) - - else: - self.variables = sorted( - data_sources.get_vars_with_stat_name(statistics_per_var, variables=variables) - ) + # if variables is None: + # self.variables = sorted( + # data_sources.get_vars_with_stat_name(statistics_per_var, variables=None) + # ) + # + # else: + self.variables = sorted(data_sources.get_vars_with_stat_name(statistics_per_var, variables=variables)) self.history = None self.label = None @@ -320,8 +318,8 @@ class DataHandlerSingleStation(AbstractDataHandler): def set_inputs_and_targets(self): inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)}) targets = self._data.sel( - {self.target_dim: helpers.to_list(self.target_var) - }) # ToDo: is it right to expand this dim?? + {self.target_dim: helpers.to_list(self.target_var) # ToDo: is it right to expand this dim?? + }) self.input_data = inputs self.target_data = targets @@ -332,7 +330,7 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: :rtype: """ - t_var = helpers.data_sources.get_target_var_with_stat_name(self.statistics_per_var, self._target_var) + t_var = helpers.data_sources.get_single_var_with_stat_name(self.statistics_per_var, self._target_var) return t_var def make_samples(self): diff --git a/mlair/helpers/data_sources/__init__.py b/mlair/helpers/data_sources/__init__.py index 2ae2d78a732bf56dbd0d8e91a6b48d10694aff2c..cd6014a64fc089fb80e47da44d6b1393c7a949b5 100644 --- a/mlair/helpers/data_sources/__init__.py +++ b/mlair/helpers/data_sources/__init__.py @@ -9,4 +9,4 @@ __date__ = "2022-07-05" from . import era5, join, toar_data, toar_data_v2 -from .toar_data import get_vars_with_stat_name, get_target_var_with_stat_name +from .toar_data import get_vars_with_stat_name, get_single_var_with_stat_name diff --git a/mlair/helpers/data_sources/join.py b/mlair/helpers/data_sources/join.py index 30ba31961c459f4bde9d7e4d906246443b46611c..8e7d21ce72ac45e16f1513eed912e4229478093f 100644 --- a/mlair/helpers/data_sources/join.py +++ b/mlair/helpers/data_sources/join.py @@ -11,7 +11,7 @@ import pandas as pd from mlair import helpers from mlair.configuration.join_settings import join_settings from mlair.helpers.data_sources import toar_data, toar_data_v2 - +from mlair.helpers.helpers import convert_to_comma_separated_string # join_url_base = 'https://join.fz-juelich.de/services/rest/surfacedata/' str_or_none = Union[str, None] @@ -106,20 +106,6 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t raise toar_data.EmptyQueryResult("No data found in JOIN.") -def convert_to_comma_separated_string(inputpattern: list): - """ - Converts a list to string using a kommata as separator. - - - :param inputpattern: - :type inputpattern: - :return: - :rtype: - """ - return ','.join(helpers.to_list(inputpattern)) - - - def _correct_meta(meta): meta_out = {} for k, v in meta.items(): diff --git a/mlair/helpers/data_sources/toar_data.py b/mlair/helpers/data_sources/toar_data.py index af41aac2663be716c30b6a88f83b581345c73fa1..700fbb68c2f9ad41812d8b00c1f9e6c35bbf5687 100644 --- a/mlair/helpers/data_sources/toar_data.py +++ b/mlair/helpers/data_sources/toar_data.py @@ -132,9 +132,9 @@ def correct_stat_name(stat: str) -> str: return mapping.get(stat, stat) -def get_vars_with_stat_name(data: dict, variables: Union[list, None] = None) -> list: +def get_vars_with_stat_name(data: dict, variables: Union[List, None] = None) -> List[str]: """ - Returns a list of variable names consisting of var_name and applied stats: e.g. o3_dma8eu + Returns a list of variable names consisting of var_name and applied stats: e.g. ['o3_dma8eu', 'o3_p95'] :param data: Dictionary containing variable names as keys and aggredation statistic(s) as values :type data: dict :param variables: optional variables if only a subset from data.keys() is used @@ -149,18 +149,18 @@ def get_vars_with_stat_name(data: dict, variables: Union[list, None] = None) -> return [f"{v}_{correct_stat_name(st)}" for v in selected_vars for st in to_list(data[v])] -def get_target_var_with_stat_name(data: dict, variable: str, idx_target_stat=0) -> str: +def get_single_var_with_stat_name(data: dict, variable: str, idx_stat: int = 0) -> str: """ - - :param data: - :type data: - :param variable: - :type variable: - :param idx_target_stat: - :type idx_target_stat: - :return: - :rtype: + Returns combination of variable name and statistic as string (e.g. 'o3_dma8eu') + :param data: Dictionary containing variable names as keys and aggredation statistic(s) as values + :type data: dict + :param variable: Variable that will be extracted from data.keys() + :type variable: string + :param idx_stat: index of list entry from data[variable] to be combined with variable name + :type idx_stat: ind + :return: variable name in combination with statistic (e.g. 'o3_dma8eu') + :rtype: str """ - return get_vars_with_stat_name(data, to_list(variable))[idx_target_stat] + return get_vars_with_stat_name(data, to_list(variable))[idx_stat] diff --git a/mlair/helpers/data_sources/toar_data_v2.py b/mlair/helpers/data_sources/toar_data_v2.py index 0d46229b8069fd124687fc3ed8a0d6ca1ce3122a..745d18ccab1395e006ca3b049cece5d81e949e71 100644 --- a/mlair/helpers/data_sources/toar_data_v2.py +++ b/mlair/helpers/data_sources/toar_data_v2.py @@ -65,15 +65,17 @@ def download_toar(station_name: Union[str, List[str]], stat_var: dict, data_url_base, headers = toar_data_v2_settings(sampling) data_dict = {} + data_coll = [] for var, meta in timeseries_meta.items(): logging.debug(f"load {var}") meta_and_opts = prepare_meta(meta, sampling, stat_var, var) data_var = [] for var_meta, opts in meta_and_opts: data_var.extend(load_timeseries_data(var_meta, data_url_base, opts, headers, sampling)) - data_dict[var] = merge_data(*data_var, sampling=sampling) - # data = pd.DataFrame.from_dict(data_dict) - data = pd.concat(data_dict.values(), axis=1) + # data_dict[var] = merge_data(*data_var, sampling=sampling) + data_coll.append(merge_data(*data_var, sampling=sampling)) + # data = pd.concat(data_dict.values(), axis=1) + data = pd.concat(data_coll, axis=1) data = correct_timezone(data, station_meta, sampling) meta = combine_meta_data(station_meta, {k: v[0] for k, v in timeseries_meta.items()}) diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py index 7ec262033f781294c1ee71a885b3b243136fa47c..6bd616c6f17081544a2eb379a427b091dca6c9b1 100644 --- a/mlair/helpers/helpers.py +++ b/mlair/helpers/helpers.py @@ -16,6 +16,8 @@ from tensorflow.keras.models import Model from tensorflow.python.keras.layers import deserialize, serialize from tensorflow.python.keras.saving import saving_utils +from mlair import helpers + """ The following code is copied from: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883 and is a hotfix to make keras.model.model models serializable/pickable @@ -310,3 +312,14 @@ def str2bool(v): # elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): # size += sum([get_size(i, seen) for i in obj]) # return size +def convert_to_comma_separated_string(inputpattern: Union[list, str]) -> str: + """ + Converts a list (or string) to string using a comma as separator. + + + :param inputpattern: + :type inputpattern: + :return: + :rtype: + """ + return ','.join(helpers.to_list(inputpattern)) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 86eadfe6dd26f0634e858ee42dee29542e1a2504..bff77e86fadf50d3f0616aaa3c1834715340599d 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -84,7 +84,7 @@ class PostProcessing(RunEnvironment): self.forecast_path = self.data_store.get("forecast_path") self.plot_path: str = self.data_store.get("plot_path") self.target_var = self.data_store.get("target_var") - self.target_var_with_stat = data_sources.get_target_var_with_stat_name( + self.target_var_with_stat = data_sources.get_single_var_with_stat_name( self.data_store.get("statistics_per_var"), self.target_var) self._sampling = self.data_store.get("sampling") self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))