diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index d8991040dfe51246706d132bb64a335475ce7db1..c26f97cdd7ae43a7f6026801aef39435d517c428 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -60,7 +60,7 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): self.set_inputs_and_targets() def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: - vars = [self.variables, self.target_var][ind] + vars = [self.variables, self.target_var] stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], self.station_type, self.network, self.store_data_locally, self.data_origin, @@ -115,7 +115,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def make_input_target(self): """ - A FIR filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values + A FIR filter is applied on the input data that has hourly resolution. Labels Y are provided as aggregated values with daily resolution. """ self._data = tuple(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data @@ -143,7 +143,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: start, end = self.update_start_end(ind) - vars = [self.variables, self.target_var][ind] + vars = [self.variables, self.target_var] stats_per_var = helpers.select_from_dict(self.statistics_per_var, vars[ind]) data, self.meta = self.load_data(self.path[ind], self.station, stats_per_var, self.sampling[ind], @@ -353,6 +353,7 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_chem") cls.prepare_build(sp_keys, chem_vars, cls.chem_indicator) + cls.correct_overwrite_option(sp_keys) sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) if len(meteo_vars) > 0: cls.set_data_handler_fir_pos(**kwargs) @@ -364,11 +365,18 @@ class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWi sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} sp_keys = cls.build_update_transformation(sp_keys, dh_type="unfiltered_meteo") cls.prepare_build(sp_keys, meteo_vars, cls.meteo_indicator) + cls.correct_overwrite_option(sp_keys) sp_meteo_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} return cls(sp_chem, sp_meteo, sp_chem_unfiltered, sp_meteo_unfiltered, chem_vars, meteo_vars, **dp_args) + @classmethod + def correct_overwrite_option(cls, kwargs): + """Set `overwrite_local_data=False`.""" + if "overwrite_local_data" in kwargs: + kwargs["overwrite_local_data"] = False + @classmethod def set_data_handler_fir_pos(cls, **kwargs): """ diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index d9b1fa8a7914579ace86b764d948e01f915acd2a..6d3407eefd9b4a96f9e73f5f7e21fead0369d37b 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -395,11 +395,11 @@ class DataHandlerSingleStation(AbstractDataHandler): era5_stats, join_stats = statistics_per_var, statistics_per_var # load data - if era5_origin is not None and len(era5_origin) > 0: + if era5_origin is not None and len(era5_stats) > 0: # load era5 data df_era5, meta_era5 = era5.load_era5(station_name=station, stat_var=era5_stats, sampling=sampling, data_origin=era5_origin) - if join_origin is None or len(join_stats.keys()) > 0: + if join_origin is None or len(join_stats) > 0: # load join data df_join, meta_join = join.download_join(station_name=station, stat_var=join_stats, station_type=station_type, network_name=network, sampling=sampling, data_origin=join_origin)