diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index 51d4beafbbc0b346331db80567946c3acc702b8e..3da91b18c9af86abaa9492f2bc7ed6c15dc9fe5e 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -49,6 +49,7 @@ DEFAULT_NUMBER_OF_BOOTSTRAPS = 20
 DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                      "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
                      "PlotAvailability"]
+DEFAULT_SAMPLING = "daily"
 
 
 def get_defaults():
diff --git a/mlair/configuration/path_config.py b/mlair/configuration/path_config.py
index 9b3d6f250d97d93dd1d06004690885f44de30073..bf40c361e121c409efec08b85fdf4e19848049ee 100644
--- a/mlair/configuration/path_config.py
+++ b/mlair/configuration/path_config.py
@@ -20,7 +20,7 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str:
 
     :param create_new: Create new path if enabled
     :param data_path: Parse your custom path (and therefore ignore preset paths fitting to known hosts)
-    :param sampling: sampling rate to separate data physically by temporal resolution
+    :param sampling: sampling rate to separate data physically by temporal resolution (deprecated)
 
     :return: full path to data
     """
@@ -32,17 +32,14 @@ def prepare_host(create_new=True, data_path=None, sampling="daily") -> str:
             data_path = f"/home/{user}/Data/toar_{sampling}/"
         elif hostname == "zam347":
             data_path = f"/home/{user}/Data/toar_{sampling}/"
-        elif hostname == "linux-aa9b":
-            data_path = f"/home/{user}/mlair/data/toar_{sampling}/"
         elif (len(hostname) > 2) and (hostname[:2] == "jr"):
             data_path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/"
         elif (len(hostname) > 2) and (hostname[:2] in ['jw', 'ju'] or hostname[:5] in ['hdfml']):
-            data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/toar_{sampling}/"
+            data_path = f"/p/project/deepacf/intelliaq/{user}/DATA/MLAIR/"
         elif runner_regex.match(hostname) is not None:
-            data_path = f"/home/{user}/mlair/data/toar_{sampling}/"
+            data_path = f"/home/{user}/mlair/data/"
         else:
-            data_path = os.path.join(os.getcwd(), "data", sampling)
-            # raise OSError(f"unknown host '{hostname}'")
+            data_path = os.path.join(os.getcwd(), "data")
 
     if not os.path.exists(data_path):
         try:
@@ -97,7 +94,7 @@ def set_experiment_name(name: str = None, sampling: str = None) -> str:
     return experiment_name
 
 
-def set_bootstrap_path(bootstrap_path: str, data_path: str, sampling: str) -> str:
+def set_bootstrap_path(bootstrap_path: str, data_path: str) -> str:
     """
     Set path for bootstrap input data.
 
@@ -105,12 +102,11 @@ def set_bootstrap_path(bootstrap_path: str, data_path: str, sampling: str) -> st
 
     :param bootstrap_path: custom path to store bootstrap data
     :param data_path: path of data for default bootstrap path
-    :param sampling: sampling rate to add, if path is set to default
 
     :return: full bootstrap path
     """
     if bootstrap_path is None:
-        bootstrap_path = os.path.join(data_path, "..", f"bootstrap_{sampling}")
+        bootstrap_path = os.path.join(data_path, "bootstrap")
     check_path_and_create(bootstrap_path)
     return os.path.abspath(bootstrap_path)
 
diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py
index ce96a8f5c039b5a232aa56765209927dd4019168..de1cb071369395edd9a8b6e869d65561dbfa0f11 100644
--- a/mlair/data_handler/data_handler_kz_filter.py
+++ b/mlair/data_handler/data_handler_kz_filter.py
@@ -24,7 +24,7 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
     _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"])
 
     def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs):
-        assert kwargs.get("sampling") == "hourly"  # This data handler requires hourly data resolution
+        self._check_sampling(**kwargs)
         kz_filter_length = to_list(kz_filter_length)
         kz_filter_iter = to_list(kz_filter_iter)
         # self.original_data = None  # ToDo: implement here something to store unfiltered data
@@ -34,12 +34,17 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation):
         self.cutoff_period_days = None
         super().__init__(*args, **kwargs)
 
+    def _check_sampling(self, **kwargs):
+        assert kwargs.get("sampling") == "hourly"  # This data handler requires hourly data resolution
+
     def setup_samples(self):
         """
         Setup samples. This method prepares and creates samples X, and labels Y.
""" - self.load_data() - self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) + data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + self.station_type, self.network, self.store_data_locally) + self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) self.set_inputs_and_targets() self.apply_kz_filter() # this is just a code snippet to check the results of the kz filter diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ca5e3389c56b047358f02ba2f78bd0a7f6728f --- /dev/null +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -0,0 +1,97 @@ +__author__ = 'Lukas Leufen' +__date__ = '2020-11-05' + +from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation +from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation +from mlair.data_handler import DefaultDataHandler +from mlair.configuration import path_config +from mlair import helpers +from mlair.helpers import remove_items +from mlair.configuration.defaults import DEFAULT_SAMPLING + +import logging +import os +import inspect + +import pandas as pd +import xarray as xr + + +class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): + _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) + + def __init__(self, *args, sampling_inputs, **kwargs): + sampling = (sampling_inputs, kwargs.get("sampling", DEFAULT_SAMPLING)) + kwargs.update({"sampling": sampling}) + super().__init__(*args, **kwargs) + + def setup_samples(self): + """ + Setup samples. This method prepares and creates samples X, and labels Y. + """ + self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data + self.set_inputs_and_targets() + if self.do_transformation is True: + self.call_transform() + self.make_samples() + + def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: + data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind], + self.station_type, self.network, self.store_data_locally) + data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) + return data + + def set_inputs_and_targets(self): + inputs = self._data[0].sel({self.target_dim: helpers.to_list(self.variables)}) + targets = self._data[1].sel({self.target_dim: self.target_var}) + self.input_data.data = inputs + self.target_data.data = targets + + def setup_data_path(self, data_path, sampling): + """Sets two paths instead of single path. 
Expects sampling arg to be a list with two entries""" + assert len(sampling) == 2 + return list(map(lambda x: super(__class__, self).setup_data_path(data_path, x), sampling)) + + +class DataHandlerMixedSampling(DefaultDataHandler): + """Data handler using mixed sampling for input and target.""" + + data_handler = DataHandlerMixedSamplingSingleStation + data_handler_transformation = DataHandlerMixedSamplingSingleStation + _requirements = data_handler.requirements() + + +class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation, + DataHandlerKzFilterSingleStation): + _requirements1 = DataHandlerKzFilterSingleStation.requirements() + _requirements2 = DataHandlerMixedSamplingSingleStation.requirements() + _requirements = list(set(_requirements1 + _requirements2)) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _check_sampling(self, **kwargs): + assert kwargs.get("sampling") == ("hourly", "daily") + + def setup_samples(self): + """ + Setup samples. This method prepares and creates samples X, and labels Y. + + A KZ filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values + with daily resolution. + """ + self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data + self.set_inputs_and_targets() + self.apply_kz_filter() + if self.do_transformation is True: + self.call_transform() + self.make_samples() + + +class DataHandlerMixedSamplingWithFilter(DefaultDataHandler): + """Data handler using mixed sampling for input and target. Inputs are temporal filtered.""" + + data_handler = DataHandlerMixedSamplingWithFilterSingleStation + data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation + _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index 460d1c100dadbc2aea5d43932e902cc080177b27..cd922e7535124b2d83be2ac9aa3e53f5df949ba6 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -52,7 +52,7 @@ class DataHandlerSingleStation(AbstractDataHandler): min_length: int = 0, start=None, end=None, variables=None, **kwargs): super().__init__() # path, station, statistics_per_var, transformation, **kwargs) self.station = helpers.to_list(station) - self.path = os.path.abspath(data_path) + self.path = self.setup_data_path(data_path, sampling) self.statistics_per_var = statistics_per_var self.do_transformation = transformation is not None self.input_data, self.target_data = self.setup_transformation(transformation) @@ -141,8 +141,10 @@ class DataHandlerSingleStation(AbstractDataHandler): """ Setup samples. This method prepares and creates samples X, and labels Y. 
""" - self.load_data() - self.interpolate(dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) + data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, + self.station_type, self.network, self.store_data_locally) + self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, + limit=self.interpolation_limit) self.set_inputs_and_targets() if self.do_transformation is True: self.call_transform() @@ -160,7 +162,8 @@ class DataHandlerSingleStation(AbstractDataHandler): self.make_observation(self.target_dim, self.target_var, self.time_dim) self.remove_nan(self.time_dim) - def read_data_from_disk(self, source_name=""): + def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, + store_data_locally=False): """ Load data and meta data either from local disk (preferred) or download new data by using a custom download method. @@ -168,35 +171,42 @@ class DataHandlerSingleStation(AbstractDataHandler): cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not set, it is assumed, that data should be saved locally. """ - source_name = source_name if len(source_name) == 0 else f" from {source_name}" - check_path_and_create(self.path) - file_name = self._set_file_name() - meta_file = self._set_meta_file_name() + check_path_and_create(path) + file_name = self._set_file_name(path, station, statistics_per_var) + meta_file = self._set_meta_file_name(path, station, statistics_per_var) if self.overwrite_local_data is True: - logging.debug(f"overwrite_local_data is true, therefore reload {file_name}{source_name}") + logging.debug(f"overwrite_local_data is true, therefore reload {file_name}") if os.path.exists(file_name): os.remove(file_name) if os.path.exists(meta_file): os.remove(meta_file) - data, self.meta = self.download_data(file_name, meta_file) - logging.debug(f"loaded new data{source_name}") + data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, + station_type=station_type, network=network, + store_data_locally=store_data_locally) + logging.debug(f"loaded new data") else: try: logging.debug(f"try to load local data from: {file_name}") data = xr.open_dataarray(file_name) - self.meta = pd.read_csv(meta_file, index_col=0) - self.check_station_meta() + meta = pd.read_csv(meta_file, index_col=0) + self.check_station_meta(meta, station, station_type, network) logging.debug("loading finished") except FileNotFoundError as e: logging.debug(e) - logging.debug(f"load new data{source_name}") - data, self.meta = self.download_data(file_name, meta_file) + logging.debug(f"load new data") + data, meta = self.download_data(file_name, meta_file, station, statistics_per_var, sampling, + station_type=station_type, network=network, + store_data_locally=store_data_locally) logging.debug("loading finished") # create slices and check for negative concentration. data = self._slice_prep(data) - self._data = self.check_for_negative_concentrations(data) + data = self.check_for_negative_concentrations(data) + return data, meta - def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: + @staticmethod + def download_data_from_join(file_name: str, meta_file: str, station, statistics_per_var, sampling, + station_type=None, network=None, store_data_locally=True) -> [xr.DataArray, + pd.DataFrame]: """ Download data from TOAR database using the JOIN interface. 
@@ -209,36 +219,37 @@ class DataHandlerSingleStation(AbstractDataHandler):
         :return: downloaded data and its meta data
         """
         df_all = {}
-        df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var,
-                                      station_type=self.station_type, network_name=self.network, sampling=self.sampling)
-        df_all[self.station[0]] = df
+        df, meta = join.download_join(station_name=station, stat_var=statistics_per_var, station_type=station_type,
+                                      network_name=network, sampling=sampling)
+        df_all[station[0]] = df
         # convert df_all to xarray
         xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
         xarr = xr.Dataset(xarr).to_array(dim='Stations')
-        if self.store_data_locally is True:
+        if store_data_locally is True:
             # save locally as nc/csv file
             xarr.to_netcdf(path=file_name)
             meta.to_csv(meta_file)
         return xarr, meta
 
-    def download_data(self, file_name, meta_file):
-        data, meta = self.download_data_from_join(file_name, meta_file)
+    def download_data(self, *args, **kwargs):
+        data, meta = self.download_data_from_join(*args, **kwargs)
         return data, meta
 
-    def check_station_meta(self):
+    @staticmethod
+    def check_station_meta(meta, station, station_type, network):
         """
         Search for the entries in meta data and compare the value with the requested values.
 
         Will raise a FileNotFoundError if the values mismatch.
         """
-        if self.station_type is not None:
-            check_dict = {"station_type": self.station_type, "network_name": self.network}
+        if station_type is not None:
+            check_dict = {"station_type": station_type, "network_name": network}
             for (k, v) in check_dict.items():
                 if v is None:
                     continue
-                if self.meta.at[k, self.station[0]] != v:
+                if meta.at[k, station[0]] != v:
                     logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
-                                  f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
+                                  f"{meta.at[k, station[0]]} (local). Raise FileNotFoundError to trigger new "
                                   f"grapping from web.")
                     raise FileNotFoundError
 
@@ -257,10 +268,14 @@ class DataHandlerSingleStation(AbstractDataHandler):
         """
         chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
                      "propane", "so2", "toluene"]
+        # used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys()))
         used_chem_vars = list(set(chem_vars) & set(self.variables))
         data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
+    def setup_data_path(self, data_path: str, sampling: str):
+        return os.path.join(os.path.abspath(data_path), sampling)
+
     def shift(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray:
         """
         Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0).
@@ -303,15 +318,18 @@ class DataHandlerSingleStation(AbstractDataHandler):
         res.name = index_name
         return res
 
-    def _set_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}.nc")
+    @staticmethod
+    def _set_file_name(path, station, statistics_per_var):
+        all_vars = sorted(statistics_per_var.keys())
+        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}.nc")
 
-    def _set_meta_file_name(self):
-        all_vars = sorted(self.statistics_per_var.keys())
-        return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv")
+    @staticmethod
+    def _set_meta_file_name(path, station, statistics_per_var):
+        all_vars = sorted(statistics_per_var.keys())
+        return os.path.join(path, f"{''.join(station)}_{'_'.join(all_vars)}_meta.csv")
 
-    def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
+    @staticmethod
+    def interpolate(data, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True,
                     **kwargs):
         """
         Interpolate values according to different methods.
@@ -349,8 +367,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         :return: xarray.DataArray
         """
-        self._data = self._data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate,
-                                               **kwargs)
+        return data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs)
 
     def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
         """
@@ -452,25 +469,6 @@ class DataHandlerSingleStation(AbstractDataHandler):
         """
         return data.loc[{coord: slice(str(start), str(end))}]
 
-    def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray:
-        """
-        Set all negative concentrations to zero.
-
-        Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/
-        #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox",
-        "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene".
-
-        :param data: data array containing variables to check
-        :param minimum: minimum value, by default this should be 0
-
-        :return: corrected data
-        """
-        chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5",
-                     "propane", "so2", "toluene"]
-        used_chem_vars = list(set(chem_vars) & set(self.statistics_per_var.keys()))
-        data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
-        return data
-
     @staticmethod
     def setup_transformation(transformation: statistics.TransformationClass):
         """
@@ -490,13 +488,6 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         else:
             raise NotImplementedError("Cannot handle this.")
 
-    def load_data(self):
-        try:
-            self.read_data_from_disk()
-        except FileNotFoundError:
-            self.download_data()
-            self.load_data()
-
     def transform(self, data_class, dim: Union[str, int] = 0, transform_method: str = 'standardise',
                   inverse: bool = False, mean=None, std=None, min=None, max=None) -> None:
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index e6dde10bf6bd13013fa454eadd1a7976c00dd3e2..584151e36fd0c9621d089e88b8ad61cffa0c5925 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -30,7 +30,7 @@ class DefaultDataHandler(AbstractDataHandler):
 
     _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"])
 
-    def __init__(self, id_class: data_handler, data_path: str, min_length: int = 0,
+    def __init__(self, id_class: data_handler, experiment_path: str, min_length: int = 0,
                  extreme_values: num_or_list = None, extremes_on_right_tail_only: bool = False, name_affix=None,
                  store_processed_data=True):
         super().__init__()
@@ -42,7 +42,7 @@ class DefaultDataHandler(AbstractDataHandler):
         self._X_extreme = None
         self._Y_extreme = None
         _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
-        self._save_file = os.path.join(data_path, f"data_preparation_{_name_affix}.pickle")
+        self._save_file = os.path.join(experiment_path, "data", f"{_name_affix}.pickle")
         self._collection = self._create_collection()
         self.harmonise_X()
         self.multiply_extremes(extreme_values, extremes_on_right_tail_only, dim=self.interpolation_dim)
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 5de7ef5f788ddcee591c20f6f9125813cec5205a..7fb19e29baed5709e30e4069aa3d681f04e38267 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -17,7 +17,7 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_TRAIN_START, DEFAULT_TRAIN_END, DEFAULT_TRAIN_MIN_LENGTH, DEFAULT_VAL_START, DEFAULT_VAL_END, \
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
-    DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST
+    DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.model_class import MyLittleModel as VanillaModel
@@ -184,7 +184,7 @@ class ExperimentSetup(RunEnvironment):
             training) set for a second time to the sample. If multiple valus are given, a sample is added for each
             exceedence once. E.g. a sample with `value=2.5` occurs twice in the training set for given
             `extreme_values=[2, 3]`, whereas a sample with `value=5` occurs three times in the training set. For default,
-            upsampling of extreme values is disabled (`None`). Upsamling can be modified to manifold only values that are
+            upsampling of extreme values is disabled (`None`). Upsampling can be modified to manifold only values that are
             actually larger than given values from ``extreme_values`` (apply only on right side of distribution) by
             using ``extremes_on_right_tail_only``. This can be useful for positive skew variables.
         :param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If ``extremes_on_right_tail_only``
@@ -214,20 +214,25 @@
                  dimensions=None,
                  time_dim=None,
                  interpolation_method=None,
-                 interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
-                 test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None, fraction_of_train: float = None,
-                 experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data = None, sampling: str = "daily",
-                 create_new_model = None, bootstrap_path=None, permute_data_on_training = None, transformation=None,
+                 interpolation_limit=None, train_start=None, train_end=None, val_start=None, val_end=None,
+                 test_start=None,
+                 test_end=None, use_all_stations_on_all_data_sets=None, train_model: bool = None,
+                 fraction_of_train: float = None,
+                 experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data=None,
+                 sampling: str = None,
+                 create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
-                 extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None, number_of_bootstraps=None,
+                 extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
+                 number_of_bootstraps=None,
                  create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
-                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, **kwargs):
+                 hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None, sampling_inputs=None,
+                 sampling_outputs=None, **kwargs):
 
         # create run framework
         super().__init__()
 
         # experiment setup, hyperparameters
-        self._set_param("data_path", path_config.prepare_host(data_path=data_path, sampling=sampling))
+        self._set_param("data_path", path_config.prepare_host(data_path=data_path))
         self._set_param("hostname", path_config.get_host())
         self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST)
         self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST)
@@ -235,7 +240,7 @@ class ExperimentSetup(RunEnvironment):
         if self.data_store.get("create_new_model"):
             train_model = True
         data_path = self.data_store.get("data_path")
-        bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path, sampling)
+        bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path)
         self._set_param("bootstrap_path", bootstrap_path)
         self._set_param("train_model", train_model, default=DEFAULT_TRAIN_MODEL)
         self._set_param("fraction_of_training", fraction_of_train, default=DEFAULT_FRACTION_OF_TRAINING)
@@ -250,6 +255,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("epochs", epochs, default=DEFAULT_EPOCHS)
 
         # set experiment name
+        sampling = self._set_param("sampling", sampling, default=DEFAULT_SAMPLING)  # always related to output sampling
         experiment_name = path_config.set_experiment_name(name=experiment_date, sampling=sampling)
         experiment_path = path_config.set_experiment_path(name=experiment_name, path=experiment_path)
         self._set_param("experiment_name", experiment_name)
@@ -287,7 +293,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
         self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
                         scope="preprocessing")
-        self._set_param("sampling", sampling)
+        self._set_param("sampling_inputs", sampling_inputs, default=sampling)
         self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
         self._set_param("transformation", None, scope="preprocessing")
         self._set_param("data_handler", data_handler, default=DefaultDataHandler)
@@ -356,7 +362,7 @@ class ExperimentSetup(RunEnvironment):
                     f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}")
 
     def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general",
-                   apply: Callable = None) -> None:
+                   apply: Callable = None) -> Any:
         """Set given parameter and log in debug. Use apply parameter to adjust the stored value (e.g. to transform
         value to a list use apply=helpers.to_list)."""
         if value is None and default is not None:
@@ -365,6 +371,7 @@ class ExperimentSetup(RunEnvironment):
             value = apply(value)
         self.data_store.set(param, value, scope)
         logging.debug(f"set experiment attribute: {param}({scope})={value}")
+        return value
 
     def _compare_variables_and_statistics(self):
         """
diff --git a/run_hourly_kz.py b/run_hourly_kz.py
new file mode 100644
index 0000000000000000000000000000000000000000..5536b56e732d81b84dfee7f34bd68d0d2ba49020
--- /dev/null
+++ b/run_hourly_kz.py
@@ -0,0 +1,31 @@
+__author__ = "Lukas Leufen"
+__date__ = '2019-11-14'
+
+import argparse
+
+from mlair.workflows import DefaultWorkflow
+from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilter
+
+
+def main(parser_args):
+    args = dict(sampling="hourly",
+                window_history_size=24, **parser_args.__dict__,
+                data_handler=DataHandlerKzFilter,
+                kz_filter_length=[365 * 24, 20 * 24],  # 13,5# , 4 * 24, 12, 6],
+                kz_filter_iter=[3, 5],  # 3,4# , 3, 4, 4],
+                start="2006-01-01",
+                train_start="2006-01-01",
+                end="2011-12-31",
+                test_end="2011-12-31",
+                stations=["DEBW107", "DEBW013"]
+                )
+    workflow = DefaultWorkflow(**args)
+    workflow.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
+                        help="set experiment date as string")
+    args = parser.parse_args(["--experiment_date", "testrun"])
+    main(args)
diff --git a/run_mixed_sampling.py b/run_mixed_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..56eef7f29872f9ab0ab995935a9008bfdfc6f930
--- /dev/null
+++ b/run_mixed_sampling.py
@@ -0,0 +1,34 @@
+__author__ = "Lukas Leufen"
+__date__ = '2019-11-14'
+
+import argparse
+
+from mlair.workflows import DefaultWorkflow
+from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, DataHandlerMixedSamplingWithFilter
+
+
+def main(parser_args):
+    args = dict(sampling="daily",
+                sampling_inputs="hourly",
+                window_history_size=72,
+                **parser_args.__dict__,
+                data_handler=DataHandlerMixedSampling,  # WithFilter,
+                kz_filter_length=[365 * 24, 20 * 24],
+                kz_filter_iter=[3, 5],
+                start="2006-01-01",
+                train_start="2006-01-01",
+                end="2011-12-31",
+                test_end="2011-12-31",
+                stations=["DEBW107", "DEBW013"],
+                epochs=100,
+                )
+    workflow = DefaultWorkflow(**args)
+    workflow.run()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
+                        help="set experiment date as string")
+    args = parser.parse_args(["--experiment_date", "testrun"])
+    main(args)
diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py
index b97763632922fc2aaffaf267cfbc76ff99e25b6f..2ba80a3bdf62b7fdf10b645da75769435cf7b6b9 100644
--- a/test/test_configuration/test_path_config.py
+++ b/test/test_configuration/test_path_config.py
@@ -11,22 +11,21 @@ from mlair.helpers import PyTestRegex
 
 class TestPrepareHost:
 
-    @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest",
+    @mock.patch("socket.gethostname", side_effect=["ZAM144", "zam347", "jrtest", "jwtest",
                                                    "runner-6HmDp9Qd-project-2411-concurrent-01"])
     @mock.patch("getpass.getuser", return_value="testUser")
     @mock.patch("os.path.exists", return_value=True)
     def test_prepare_host(self, mock_host, mock_user, mock_path):
-        assert prepare_host() == "/home/testUser/mlair/data/toar_daily/"
         assert prepare_host() == "/home/testUser/Data/toar_daily/"
         assert prepare_host() == "/home/testUser/Data/toar_daily/"
         assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar_daily/"
-        assert prepare_host() == "/p/project/deepacf/intelliaq/testUser/DATA/toar_daily/"
-        assert prepare_host() == '/home/testUser/mlair/data/toar_daily/'
+        assert prepare_host() == "/p/project/deepacf/intelliaq/testUser/DATA/MLAIR/"
+        assert prepare_host() == '/home/testUser/mlair/data/'
 
     @mock.patch("socket.gethostname", return_value="NotExistingHostName")
     @mock.patch("getpass.getuser", return_value="zombie21")
     def test_prepare_host_unknown(self, mock_user, mock_host):
-        assert prepare_host() == os.path.join(os.path.abspath(os.getcwd()), 'data', 'daily')
+        assert prepare_host() == os.path.join(os.path.abspath(os.getcwd()), 'data')
 
     @mock.patch("getpass.getuser", return_value="zombie21")
     @mock.patch("mlair.configuration.path_config.check_path_and_create", side_effect=PermissionError)
@@ -42,13 +41,13 @@ class TestPrepareHost:
         # assert "does not exist for host 'linux-aa9b'" in e.value.args[0]
         assert PyTestRegex(r"path '.*' does not exist for host '.*'\.") == e.value.args[0]
 
-    @mock.patch("socket.gethostname", side_effect=["linux-aa9b"])
+    @mock.patch("socket.gethostname", side_effect=["zam347"])
     @mock.patch("getpass.getuser", return_value="testUser")
     @mock.patch("os.path.exists", return_value=False)
     @mock.patch("os.makedirs", side_effect=None)
     def test_os_path_exists(self, mock_host, mock_user, mock_path, mock_check):
         path = prepare_host()
-        assert path == "/home/testUser/mlair/data/toar_daily/"
+        assert path == "/home/testUser/Data/toar_daily/"
 
 
 class TestSetExperimentName:
@@ -80,12 +79,12 @@ class TestSetBootstrapPath:
 
     @mock.patch("os.makedirs", side_effect=None)
     def test_bootstrap_path_is_none(self, mock_makedir):
-        bootstrap_path = set_bootstrap_path(None, 'TestDataPath/', 'daily')
-        assert bootstrap_path == os.path.abspath('TestDataPath/../bootstrap_daily')
+        bootstrap_path = set_bootstrap_path(None, 'TestDataPath/')
+        assert bootstrap_path == os.path.abspath('TestDataPath/bootstrap')
 
     @mock.patch("os.makedirs", side_effect=None)
     def test_bootstap_path_is_given(self, mock_makedir):
-        bootstrap_path = set_bootstrap_path('Test/path/to/boots', None, None)
+        bootstrap_path = set_bootstrap_path('Test/path/to/boots', None)
         assert bootstrap_path == os.path.abspath('./Test/path/to/boots')
diff --git a/test/test_run_modules/test_training.py b/test/test_run_modules/test_training.py
index c0b625ef70deeb0686b236275e6bd1182ad48d41..c2b58cbd2160bd958c76ba67649ef8caba09fcb4 100644
--- a/test/test_run_modules/test_training.py
+++ b/test/test_run_modules/test_training.py
@@ -61,7 +61,8 @@ class TestTraining:
         obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
         obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
         obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
-        os.makedirs(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
         obj.data_store.set("experiment_path", path, "general")
         os.makedirs(batch_path)
         obj.data_store.set("batch_path", batch_path, "general")
@@ -125,7 +126,8 @@ class TestTraining:
 
     @pytest.fixture
    def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
-        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(os.path.dirname(__file__), 'data'),
+        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
+                                             experiment_path=os.path.join(path, 'exp_path'),
                                              statistics_per_var=statistics_per_var, station_type="background",
                                              network="AIRBASE", sampling="daily", target_dim="variables",
                                              target_var="o3", time_dim="datetime",
@@ -169,7 +171,8 @@ class TestTraining:
 
     @pytest.fixture
     def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
-        os.makedirs(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
         os.makedirs(model_path)
         obj = RunEnvironment()
         obj.data_store.set("data_collection", data_collection, "general.train")
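
Note for reviewers (not part of the patch): the snippet below is a minimal sketch of the per-resolution data path layout introduced above by DataHandlerSingleStation.setup_data_path and its mixed-sampling override. The two functions are simplified stand-ins for the actual methods, so the names with the "_sketch" suffix and the printed paths are illustrative only.

import os

# Stand-in for DataHandlerSingleStation.setup_data_path: data for a single
# resolution lives in <data_path>/<sampling>.
def setup_data_path_sketch(data_path: str, sampling: str) -> str:
    return os.path.join(os.path.abspath(data_path), sampling)

# Stand-in for the override in DataHandlerMixedSamplingSingleStation: sampling
# is the (inputs, targets) tuple built in __init__ from sampling_inputs and the
# output sampling (DEFAULT_SAMPLING, "daily"), yielding one path per resolution.
def setup_data_path_mixed_sketch(data_path: str, sampling: tuple) -> list:
    assert len(sampling) == 2
    return [setup_data_path_sketch(data_path, s) for s in sampling]

print(setup_data_path_sketch("data", "daily"))              # <cwd>/data/daily
print(setup_data_path_mixed_sketch("data", ("hourly", "daily")))
# [<cwd>/data/hourly, <cwd>/data/daily]

With this layout, load_and_interpolate(ind) in the mixed-sampling handler picks self.path[ind] and self.sampling[ind], so inputs (index 0, hourly) and targets (index 1, daily) are loaded from separate directories while sharing the same station file naming scheme.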