From c0ac118c9ae6925105c11ae9fe77508fcefbc495 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 8 Mar 2022 10:01:13 +0100 Subject: [PATCH] apply changes from #369 and #370 --- .../data_handler_mixed_sampling.py | 30 ------------------- .../data_handler/data_handler_with_filter.py | 15 ++++++++++ mlair/run_modules/experiment_setup.py | 2 +- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 6c256f86..0bdd9b21 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -283,21 +283,6 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): dh_transformation=dh_transformation[1], **kwargs) return {"filtered": transformation_filtered, "unfiltered": transformation_unfiltered} - def get_X_original(self): - if self.use_filter_branches is True: - X = [] - for data in self._collection: - if hasattr(data, "filter_dim"): - X_total = data.get_X() - filter_dim = data.filter_dim - for filter_name in data.filter_dim_order: - X.append(X_total.sel({filter_dim: filter_name}, drop=True)) - else: - X.append(data.get_X()) - return X - else: - return super().get_X_original() - class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWithClimateFirFilter): # data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation @@ -457,18 +442,3 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi else: # if no unfiltered meteo branch transformation_res["filtered_meteo"] = transformation_meteo return transformation_res if len(transformation_res) > 0 else None - - def get_X_original(self): - if self.use_filter_branches is True: - X = [] - for data in self._collection: - if hasattr(data, "filter_dim"): - X_total = data.get_X() - filter_dim = data.filter_dim - for filter_name in data.filter_dim_order: - X.append(X_total.sel({filter_dim: filter_name}, drop=True)) - else: - X.append(data.get_X()) - return X - else: - return super().get_X_original() diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 997ecbf5..47ccc551 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -116,6 +116,21 @@ class DataHandlerFilter(DefaultDataHandler): self.use_filter_branches = use_filter_branches super().__init__(*args, **kwargs) + def get_X_original(self): + if self.use_filter_branches is True: + X = [] + for data in self._collection: + if hasattr(data, "filter_dim"): + X_total = data.get_X() + filter_dim = data.filter_dim + for filter_name in data.filter_dim_order: + X.append(X_total.sel({filter_dim: filter_name}, drop=True)) + else: + X.append(data.get_X()) + return X + else: + return super().get_X_original() + class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered.""" diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index aca5f583..df797ffc 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -389,7 +389,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("neighbors", ["DEBW030"]) # TODO: just for testing # set competitors - if model_display_name is not None and model_display_name in competitors: + if model_display_name is not None and competitors is not None and model_display_name in competitors: raise IndexError(f"Given model_display_name {model_display_name} is also present in the competitors " f"variable {competitors}. To assure a proper workflow it is required to have unique names " f"for each model and competitor. Please use a different model display name or competitor.") -- GitLab