diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 80890b6f45dcde80aa75e9203a4a44ba25c7db01..c16dfb4344b6d083876081b333f855a9eac99c6b 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -36,7 +36,9 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): self.make_samples() def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: - data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + vars = [self.variables, self.target_var] + stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) + data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], self.station_type, self.network, self.store_data_locally, self.data_origin, self.start, self.end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, @@ -110,7 +112,10 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi else: # target start, end = self.start, self.end - data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + vars = [self.variables, self.target_var] + stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) + + data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], self.station_type, self.network, self.store_data_locally, self.data_origin, start, end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index e554a3b32d8e4e2f5482a388374cfba87f7add15..8131566acb9d5456832770152d8776fa6827d6f4 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ 
def select_from_dict(dict_obj: dict, sel_list: Union[str, List[str]]) -> dict:
    """
    Extract all key-value pairs whose key is contained in sel_list.

    No check is performed whether every element of sel_list is a key of dict_obj. Therefore, the number of pairs in
    the returned dict is always smaller than or equal to the number of elements in sel_list.

    :param dict_obj: dictionary to select the pairs from
    :param sel_list: a single key or a list of keys to select (normalised with ``to_list`` before the lookup)

    :return: new dict containing only the selected pairs, in the iteration order of dict_obj

    :raises AssertionError: if dict_obj is not a dict
    """
    if not isinstance(dict_obj, dict):
        # Raised explicitly instead of using a bare ``assert`` so that the check is not stripped when Python runs
        # with -O. The exception type is kept as AssertionError to stay backward compatible with existing callers.
        raise AssertionError(f"dict_obj must be of type dict, but is of type {type(dict_obj).__name__}.")
    sel_list = to_list(sel_list)
    return {k: v for k, v in dict_obj.items() if k in sel_list}