From e31a8e5d17035079386fda5989a85173e0861cca Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 23 Jun 2023 14:59:26 +0200
Subject: [PATCH] adjust check meta to check only relevant parameters

---
 .../data_handler_mixed_sampling.py             |  9 ++++++---
 .../data_handler_single_station.py             | 18 +++++++++++-------
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 4a0bbf6..addf864 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -62,8 +62,9 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
     def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
         vars = [self.variables, self.target_var]
         stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind])
+        data_origin = helpers.select_from_dict(self.data_origin, vars[ind])
         data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
-                                         self.store_data_locally, self.data_origin, self.start, self.end)
+                                         self.store_data_locally, data_origin, self.start, self.end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
                                 limit=self.interpolation_limit[ind], sampling=self.sampling[ind])
 
@@ -144,9 +145,10 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
         start, end = self.update_start_end(ind)
         vars = [self.variables, self.target_var]
         stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind])
+        data_origin = helpers.select_from_dict(self.data_origin, vars[ind])
 
         data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
-                                         self.store_data_locally, self.data_origin, start, end)
+                                         self.store_data_locally, data_origin, start, end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
                                 limit=self.interpolation_limit[ind], sampling=self.sampling[ind])
         return data
@@ -474,9 +476,10 @@ class DataHandlerIFSSingleStation(DataHandlerMixedSamplingWithClimateFirFilterSi
         start, end = self.update_start_end(ind)
         vars = [self.variables, self.target_var]
         stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind])
+        data_origin = helpers.select_from_dict(self.data_origin, vars[ind])
 
         data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind],
-                                         self.store_data_locally, self.data_origin, start, end)
+                                         self.store_data_locally, data_origin, start, end)
         if ind == 1:  # only for target
             data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method[ind],
                                     limit=self.interpolation_limit[ind], sampling=self.sampling[ind])
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index e60456c..a1d3c4a 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -11,6 +11,7 @@ import dill
 import hashlib
 import logging
 import os
+import ast
 from functools import reduce, partial
 from typing import Union, List, Iterable, Tuple, Dict, Optional
 
@@ -22,7 +23,7 @@ from mlair.configuration import check_path_and_create
 from mlair import helpers
 from mlair.helpers import statistics, TimeTrackingWrapper, filter_dict_by_value, select_from_dict
 from mlair.data_handler.abstract_data_handler import AbstractDataHandler
-from mlair.helpers import data_sources
+from mlair.helpers import data_sources, check_nested_equality
 
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
@@ -299,8 +300,11 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
 
     def make_input_target(self):
-        data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling,
-                                         self.store_data_locally, self.data_origin, self.start, self.end)
+        vars = [*self.variables, self.target_var]  # flat list of names, not a nested [inputs, target] pair
+        stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars)
+        data_origin = helpers.select_from_dict(self.data_origin, vars)
+        data, self.meta = self.load_data(self.path, self.station, stats_per_var, self.sampling,
+                                         self.store_data_locally, data_origin, self.start, self.end)
         self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                       limit=self.interpolation_limit, sampling=self.sampling)
         self.set_inputs_and_targets()
@@ -368,14 +372,14 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         Will raise a FileNotFoundError if the values mismatch.
         """
-        check_dict = {"data_origin": str(data_origin), "statistics_per_var": str(statistics_per_var)}
+        check_dict = {"data_origin": data_origin, "statistics_per_var": statistics_per_var}
         for (k, v) in check_dict.items():
             if v is None or k not in meta.index:
                 continue
-            if meta.at[k, station[0]] != v:
+            m = ast.literal_eval(meta.at[k, station[0]])
+            if not check_nested_equality(select_from_dict(m, v.keys()), v):
                 logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
-                              f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
-                              f"grapping from web.")
+                              f"{m} (local). Raise FileNotFoundError to trigger new grapping from web.")
                 raise FileNotFoundError
 
     def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
-- 
GitLab