diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 4a0bbf68603fa628e34e40cf152400990009ca7c..addf864c3ca2ab88c565eb09e517a93a7960613c 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -62,8 +62,9 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: vars = [self.variables, self.target_var] stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) + data_origin = helpers.select_from_dict(self.data_origin, vars[ind]) data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], - self.store_data_locally, self.data_origin, self.start, self.end) + self.store_data_locally, data_origin, self.start, self.end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], limit=self.interpolation_limit[ind], sampling=self.sampling[ind]) @@ -144,9 +145,10 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi start, end = self.update_start_end(ind) vars = [self.variables, self.target_var] stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) + data_origin = helpers.select_from_dict(self.data_origin, vars[ind]) data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], - self.store_data_locally, self.data_origin, start, end) + self.store_data_locally, data_origin, start, end) data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], limit=self.interpolation_limit[ind], sampling=self.sampling[ind]) return data @@ -474,9 +476,10 @@ class DataHandlerIFSSingleStation(DataHandlerMixedSamplingWithClimateFirFilterSi start, end = self.update_start_end(ind) vars = [self.variables, self.target_var] stats_per_var = 
helpers.select_from_dict(self.statistics_per_var, vars[ind]) + data_origin = helpers.select_from_dict(self.data_origin, vars[ind]) data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], - self.store_data_locally, self.data_origin, start, end) + self.store_data_locally, data_origin, start, end) if ind == 1: # only for target data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind], limit=self.interpolation_limit[ind], sampling=self.sampling[ind]) diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index e60456c250449aebc5178fe4f38605b31e784f41..a1d3c4aa9ac084ee02828e05b1f068323df80505 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -11,6 +11,7 @@ import dill import hashlib import logging import os +import ast from functools import reduce, partial from typing import Union, List, Iterable, Tuple, Dict, Optional @@ -22,7 +23,7 @@ from mlair.configuration import check_path_and_create from mlair import helpers from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict from mlair.data_handler.abstract_data_handler import AbstractDataHandler -from mlair.helpers import data_sources +from mlair.helpers import data_sources, check_nested_equality # define a more general date type for type hinting date = Union[dt.date, dt.datetime] @@ -299,8 +300,11 @@ class DataHandlerSingleStation(AbstractDataHandler): self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) def make_input_target(self): - data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, - self.store_data_locally, self.data_origin, self.start, self.end) + vars = [*self.variables, self.target_var] # unpack so every variable name is a top-level selection key + stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars) + data_origin = 
helpers.select_from_dict(self.data_origin, vars) + data, self.meta = self.load_data(self.path, self.station, stats_per_var, self.sampling, + self.store_data_locally, data_origin, self.start, self.end) self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit, sampling=self.sampling) self.set_inputs_and_targets() @@ -368,14 +372,14 @@ class DataHandlerSingleStation(AbstractDataHandler): Will raise a FileNotFoundError if the values mismatch. """ - check_dict = {"data_origin": str(data_origin), "statistics_per_var": str(statistics_per_var)} + check_dict = {"data_origin": data_origin, "statistics_per_var": statistics_per_var} for (k, v) in check_dict.items(): if v is None or k not in meta.index: continue - if meta.at[k, station[0]] != v: + m = ast.literal_eval(meta.at[k, station[0]]) + if not check_nested_equality(select_from_dict(m, v.keys()), v): logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " - f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new " - f"grapping from web.") + f"{m} (local). Raise FileNotFoundError to trigger new grapping from web.") raise FileNotFoundError def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: