diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py index ce96a8f5c039b5a232aa56765209927dd4019168..f30de1bd236d7374a3bff06e218fff3e4c4b0251 100644 --- a/mlair/data_handler/data_handler_kz_filter.py +++ b/mlair/data_handler/data_handler_kz_filter.py @@ -38,8 +38,10 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): """ Setup samples. This method prepares and creates samples X, and labels Y. """ - self.load_data() - self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) + data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + self.station_type, self.network, self.store_data_locally) + self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) self.set_inputs_and_targets() self.apply_kz_filter() # this is just a code snippet to check the results of the kz filter diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index f1b5180fc00b19735c461faff72ce7b71cc90401..639e5df2fe8177b0e85adb43e4a575d352f334b3 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -2,67 +2,60 @@ __author__ = 'Lukas Leufen' __date__ = '2020-11-05' from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.data_handler import DefaultDataHandler from mlair.configuration import path_config +from mlair import helpers +from mlair.helpers import remove_items +from mlair.configuration.defaults import DEFAULT_SAMPLING import logging import os +import inspect import pandas as pd import xarray as xr -class DataHandlerMixedSampling(DataHandlerSingleStation): +class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): + _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) + + def __init__(self, *args, sampling_inputs, **kwargs): + sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING)) + kwargs.update({"sampling": sampling}) + super().__init__(*args, **kwargs) def setup_samples(self): """ Setup samples. This method prepares and creates samples X, and labels Y. """ - self.load_data() - self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) + self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data self.set_inputs_and_targets() if self.do_transformation is True: self.call_transform() self.make_samples() - def load_data(self): - try: - self.read_data_from_disk() - except FileNotFoundError: - self.download_data() - self.load_data() + def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: + data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + self.station_type, self.network, self.store_data_locally) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + return data - def read_data_from_disk(self, source_name=""): - """ - Load data and meta data either from local disk (preferred) or download new data by using a custom download method. + def set_inputs_and_targets(self): + inputs = self._data[0].sel({self.target_dim: helpers.to_list(self.variables)}) + targets = self._data[1].sel({self.target_dim: self.target_var}) + self.input_data.data = inputs + self.target_data.data = targets - Data is either downloaded, if no local data is available or parameter overwrite_local_data is true. In both - cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not - set, it is assumed, that data should be saved locally. - """ - source_name = source_name if len(source_name) == 0 else f" from {source_name}" - path_config.check_path_and_create(self.path) - file_name = self._set_file_name() - meta_file = self._set_meta_file_name() - if self.overwrite_local_data is True: - logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}") - if os.path.exists(file_name): - os.remove(file_name) - if os.path.exists(meta_file): - os.remove(meta_file) - data, self.meta = self.download_data(file_name, meta_file) - logging.debug(f"loaded new data{source_name}") - else: - try: - logging.debug(f"try to load local data from: {file_name}") - data = xr.open_dataarray(file_name) - self.meta = pd.read_csv(meta_file, index_col=0) - self.check_station_meta() - logging.debug("loading finished") - except FileNotFoundError as e: - logging.debug(e) - logging.debug(f"load new data{source_name}") - data, self.meta = self.download_data(file_name, meta_file) - logging.debug("loading finished") - # create slices and check for negative concentration. - data = self._slice_prep(data) - self._data = self.check_for_negative_concentrations(data) + def setup_data_path(self, data_path, sampling): + """Sets two paths instead of single path. Expects sampling arg to be a list with two entries""" + assert len(sampling) == 2 + return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling)) + + +class DataHandlerMixedSampling(DefaultDataHandler): + """Data handler using mixed sampling for input and target.""" + + data_handler = DataHandlerMixedSamplingSingleStation + data_handler_transformation = DataHandlerMixedSamplingSingleStation + _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index e780c62044ab64cc793e9b4c9baf5e060397a212..cd922e7535124b2d83be2ac9aa3e53f5df949ba6 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -141,9 +141,10 @@ class DataHandlerSingleStation(AbstractDataHandler): """ Setup samples. This method prepares and creates samples X, and labels Y. """ - self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, self.station_type, self.network, - self.store_data_locally) - self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) + data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + self.station_type, self.network, self.store_data_locally) + self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) self.set_inputs_and_targets() if self.do_transformation is True: self.call_transform() @@ -179,27 +180,28 @@ class DataHandlerSingleStation(AbstractDataHandler): os.remove(file_name) if os.path.exists(meta_file): os.remove(meta_file) - data, self.meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, - station_type=station_type, network=network, - store_data_locally=store_data_locally) + data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, + station_type=station_type, network=network, + store_data_locally=store_data_locally) logging.debug(f"loaded new data") else: try: logging.debug(f"try to load local data from: {file_name}") data = xr.open_dataarray(file_name) - self.meta = pd.read_csv(meta_file, index_col=0) - self.check_station_meta(station, station_type, network) + meta = pd.read_csv(meta_file, index_col=0) + self.check_station_meta(meta, station, station_type, network) logging.debug("loading finished") except FileNotFoundError as e: logging.debug(e) logging.debug(f"load new data") - data, self.meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, - station_type=station_type, network=network, - store_data_locally=store_data_locally) + data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, + station_type=station_type, network=network, + store_data_locally=store_data_locally) logging.debug("loading finished") # create slices and check for negative concentration. data = self._slice_prep(data) - self._data = self.check_for_negative_concentrations(data) + data = self.check_for_negative_concentrations(data) + return data, meta @staticmethod def download_data_from_join(file_name: str, meta_file: str, station, statistics_per_var, sampling, @@ -233,7 +235,8 @@ class DataHandlerSingleStation(AbstractDataHandler): data, meta = self.download_data_from_join(*args, **kwargs) return data, meta - def check_station_meta(self, station, station_type, network): + @staticmethod + def check_station_meta(meta, station, station_type, network): """ Search for the entries in meta data and compare the value with the requested values. @@ -244,9 +247,9 @@ class DataHandlerSingleStation(AbstractDataHandler): for (k, v) in check_dict.items(): if v is None: continue - if self.meta.at[k, station[0]] != v: + if meta.at[k, station[0]] != v: logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != " - f"{self.meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new " + f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new " f"grapping from web.") raise FileNotFoundError @@ -270,8 +273,7 @@ class DataHandlerSingleStation(AbstractDataHandler): data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) return data - @staticmethod - def setup_data_path(data_path, sampling): + def setup_data_path(self, data_path: str, sampling: str): return os.path.join(os.path.abspath(data_path), sampling) def shift(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray: @@ -326,7 +328,8 @@ class DataHandlerSingleStation(AbstractDataHandler): all_vars = sorted(statistics_per_var.keys()) return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv") - def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, + @staticmethod + def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, **kwargs): """ Interpolate values according to different methods. @@ -364,8 +367,7 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: xarray.DataArray """ - self._data = self._data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, - **kwargs) + return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs) def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: """ diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 25cd7c09f49bb8ba41970386a47d7e1936c33ab9..7fb19e29baed5709e30e4069aa3d681f04e38267 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -184,7 +184,7 @@ class ExperimentSetup(RunEnvironment): training) set for a second time to the sample. If multiple valus are given, a sample is added for each exceedence once. E.g. a sample with `value=2.5` occurs twice in the training set for given `extreme_values=[2, 3]`, whereas a sample with `value=5` occurs three times in the training set. For default, - upsampling of extreme values is disabled (`None`). Upsamling can be modified to manifold only values that are + upsampling of extreme values is disabled (`None`). Upsampling can be modified to manifold only values that are actually larger than given values from ``extreme_values`` (apply only on right side of distribution) by using ``extremes_on_right_tail_only``. This can be useful for positive skew variables. :param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If ``extremes_on_right_tail_only`` @@ -255,7 +255,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("epochs", epochs, default=DEFAULT_EPOCHS) # set experiment name - sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING) + sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING) # always related to output sampling experiment_name = path_config.set_experiment_name(name=experiment_date, sampling=sampling) experiment_path = path_config.set_experiment_path(name=experiment_name, path=experiment_path) self._set_param("experiment_name", experiment_name) diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..47678718cd6edc201585e2dae59b458899782013 --- /dev/null +++ b/run_mixed_sampling.py @@ -0,0 +1,32 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-11-14' + +import argparse + +from mlair.workflows import DefaultWorkflow +from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling + + +def main(parser_args): + args = dict(sampling="daily", + sampling_inputs="hourly", + window_history_size=72, + **parser_args.__dict__, + data_handler=DataHandlerMixedSampling, + start="2006-01-01", + train_start="2006-01-01", + end="2011-12-31", + test_end="2011-12-31", + stations=["DEBW107", "DEBW013"], + epochs=100, + ) + workflow = DefaultWorkflow(**args) + workflow.run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, + help="set experiment date as string") + args = parser.parse_args(["--experiment_date", "testrun"]) + main(args)