__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-11-15'

import argparse
import logging
import os
from typing import Union, Dict, Any

from src import helpers
from src.run_modules.run_environment import RunEnvironment

# Default station network: German (DEBW = Baden-Wuerttemberg, DEBY = Bavaria)
# measurement station IDs used when the caller does not supply a station list.
DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022',
                    'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039',
                    'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072',
                    'DEBW042', 'DEBW039', 'DEBY001', 'DEBY113', 'DEBY089', 'DEBW024', 'DEBW004', 'DEBY037',
                    'DEBW056', 'DEBW029', 'DEBY068', 'DEBW010', 'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084',
                    'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063', 'DEBY005', 'DEBW046', 'DEBW103',
                    'DEBW052', 'DEBW034', 'DEBY088', ]
# Default mapping of meteorological/chemical variable name -> statistic to
# request per variable (e.g. 'dma8eu' = daily max 8h mean, EU definition).
DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                        'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                        'pblheight': 'maximum'}
# Default data transformation settings applied during preprocessing.
DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
# Plots created by default during post-processing.
DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                     "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "plot_conditional_quantiles"]


class ExperimentSetup(RunEnvironment):
    """
    Run module that populates the shared data store with all experiment
    parameters (paths, stations, variables, time windows, train/val/test
    splits and post-processing options), applying defaults for every
    argument the caller leaves as ``None``.

    params:
    trainable: Train new model if true, otherwise try to load existing model
    """

    def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None,
                 statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3",
                 target_dim=None, window_lead_time=None, dimensions=None, interpolate_dim=None,
                 interpolate_method=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None,
                 val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None,
                 fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None,
                 overwrite_local_data=None, sampling="daily", create_new_model=None, bootstrap_path=None,
                 permute_data_on_training=None, transformation=None, evaluate_bootstraps=True, plot_list=None,
                 number_of_bootstraps=None):
        """
        Store every experiment parameter in ``self.data_store`` (provided by
        :class:`RunEnvironment`), substituting defaults for omitted values.
        Also creates the experiment, plot and forecast directories on disk.

        NOTE(review): the order of the ``_set_param`` calls matters — later
        calls read earlier values back out of the data store (e.g.
        ``statistics_per_var`` before ``variables``); do not reorder.

        :param parser_args: dict or argparse.Namespace; only the key
            ``experiment_date`` is read here.
        :param sampling: temporal resolution keyword (default ``"daily"``),
            forwarded to several path helpers.
        """
        # create run framework
        super().__init__()

        # experiment setup
        self._set_param("data_path", helpers.prepare_host(sampling=sampling))
        self._set_param("create_new_model", create_new_model, default=True)
        # a freshly created model must be trained regardless of the caller's
        # ``trainable`` argument
        if self.data_store.get("create_new_model", "general"):
            trainable = True
        data_path = self.data_store.get("data_path", "general")
        bootstrap_path = helpers.set_bootstrap_path(bootstrap_path, data_path, sampling)
        self._set_param("bootstrap_path", bootstrap_path)
        self._set_param("trainable", trainable, default=True)
        self._set_param("fraction_of_training", fraction_of_train, default=0.8)
        self._set_param("permute_data", permute_data_on_training, default=False, scope="general.train")

        # set experiment name
        exp_date = self._get_parser_args(parser_args).get("experiment_date")
        exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date, experiment_path=experiment_path,
                                                         sampling=sampling)
        self._set_param("experiment_name", exp_name)
        self._set_param("experiment_path", exp_path)
        helpers.check_path_and_create(self.data_store.get("experiment_path", "general"))

        # set plot path
        default_plot_path = os.path.join(exp_path, "plots")
        self._set_param("plot_path", plot_path, default=default_plot_path)
        helpers.check_path_and_create(self.data_store.get("plot_path", "general"))

        # set results path
        default_forecast_path = os.path.join(exp_path, "forecasts")
        self._set_param("forecast_path", forecast_path, default_forecast_path)
        helpers.check_path_and_create(self.data_store.get("forecast_path", "general"))

        # setup for data
        self._set_param("stations", stations, default=DEFAULT_STATIONS)
        self._set_param("network", network, default="AIRBASE")
        self._set_param("station_type", station_type, default=None)
        self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
        # variables default to all keys of statistics_per_var (read back from
        # the store so the default above is taken into account)
        self._set_param("variables", variables,
                        default=list(self.data_store.get("statistics_per_var", "general").keys()))
        self._compare_variables_and_statistics()
        self._set_param("start", start, default="1997-01-01", scope="general")
        self._set_param("end", end, default="2017-12-31", scope="general")
        self._set_param("window_history_size", window_history_size, default=13)
        self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="general.preprocessing")
        self._set_param("sampling", sampling)
        self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
        # presumably the preprocessing stage fills this in later; here it is
        # deliberately reset to None in its own scope — TODO confirm
        self._set_param("transformation", None, scope="general.preprocessing")

        # target
        self._set_param("target_var", target_var, default="o3")
        self._check_target_var()
        self._set_param("target_dim", target_dim, default='variables')
        self._set_param("window_lead_time", window_lead_time, default=3)

        # interpolation
        self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']})
        self._set_param("interpolate_dim", interpolate_dim, default='datetime')
        self._set_param("interpolate_method", interpolate_method, default='linear')
        self._set_param("limit_nan_fill", limit_nan_fill, default=1)

        # train set parameters
        self._set_param("start", train_start, default="1997-01-01", scope="general.train")
        self._set_param("end", train_end, default="2007-12-31", scope="general.train")

        # validation set parameters
        self._set_param("start", val_start, default="2008-01-01", scope="general.val")
        self._set_param("end", val_end, default="2009-12-31", scope="general.val")

        # test set parameters
        self._set_param("start", test_start, default="2010-01-01", scope="general.test")
        self._set_param("end", test_end, default="2017-12-31", scope="general.test")

        # train_val set parameters: spans from start of train to end of val
        self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
        self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")

        # use all stations on all data sets (train, val, test)
        self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True)

        # set post-processing instructions
        self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing")
        self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing")
        self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")

    def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
        """
        Store ``value`` under ``param`` in ``scope``; fall back to ``default``
        only when ``value`` is None AND a non-None default was given (an
        explicit None with default=None is stored as-is).
        """
        if value is None and default is not None:
            value = default
        self.data_store.set(param, value, scope)
        logging.debug(f"set experiment attribute: {param}({scope})={value}")

    @staticmethod
    def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict:
        """
        Transform args to dict if given as argparse.Namespace

        :param args: either a dictionary or an argument parser instance
        :return: dictionary with all arguments; an empty dict for any other
            input type (including None)
        """
        if isinstance(args, argparse.Namespace):
            return args.__dict__
        elif isinstance(args, dict):
            return args
        else:
            return {}

    def _compare_variables_and_statistics(self) -> None:
        """
        Raise ValueError if any entry of ``variables`` has no corresponding
        key in ``statistics_per_var`` (both read from the data store).
        """
        logging.debug("check if all variables are included in statistics_per_var")
        stat = self.data_store.get("statistics_per_var", "general")
        var = self.data_store.get("variables", "general")
        if not set(var).issubset(stat.keys()):
            missing = set(var).difference(stat.keys())
            raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested "
                             f"variables are part of statistics_per_var. Please add also information on the missing "
                             f"statistics for the variables: {missing}")

    def _check_target_var(self) -> None:
        """
        Validate that every target variable appears in ``statistics_per_var``
        (ValueError otherwise), then drop statistics entries used by neither
        ``variables`` nor the target(s) and write the pruned dict back.
        """
        target_var = helpers.to_list(self.data_store.get("target_var", "general"))
        stat = self.data_store.get("statistics_per_var", "general")
        var = self.data_store.get("variables", "general")
        if not set(target_var).issubset(stat.keys()):
            raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
        unused_vars = set(stat.keys()).difference(set(var).union(target_var))
        if len(unused_vars) > 0:
            logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}")
            stat_new = helpers.dict_pop(stat, list(unused_vars))
            self._set_param("statistics_per_var", stat_new)


if __name__ == "__main__":
    # Stand-alone entry point: run the setup stage alone with a small
    # five-station subset and debug logging.
    formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]'
    logging.basicConfig(format=formatter, level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
                        help="set experiment date as string")
    parser_args = parser.parse_args()
    with RunEnvironment():
        setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])