Skip to content
Snippets Groups Projects
Select Git revision
  • f2139653aed7d31bddd5d2a278839ed969c179fa
  • master default protected
  • tf2
  • tf2_pytorch
  • issue_3
  • issue_2
  • 2019a
  • juwels_2019a
  • jureca_2019_a
9 results

submit_job_juron_python2.sh

Blame
  • experiment_setup.py 11.41 KiB
    __author__ = "Lukas Leufen, Felix Kleinert"
    __date__ = '2019-11-15'
    
    
    import argparse
    import logging
    import os
    from typing import Union, Dict, Any
    import socket
    
    
    from src import helpers
    from src.run_modules.run_environment import RunEnvironment
    
    # Default measurement station IDs used when the caller selects no stations
    # (presumably AIRBASE identifiers: DEBW = Baden-Wuerttemberg, DEBY = Bavaria
    # -- TODO confirm against the station database).
    DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004',
                        'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039', 'DEBW038', 'DEBW081',
                        'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072', 'DEBW042', 'DEBW039', 'DEBY001',
                        'DEBY113', 'DEBY089', 'DEBW024', 'DEBW004', 'DEBY037', 'DEBW056', 'DEBW029', 'DEBY068', 'DEBW010',
                        'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084', 'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063',
                        'DEBY005', 'DEBW046', 'DEBW103', 'DEBW052', 'DEBW034', 'DEBY088', ]
    # Default mapping variable -> statistical aggregation requested for it.
    DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
                            'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
                            'pblheight': 'maximum'}
    # Default data transformation settings (standardisation with estimated mean, applied per data scope).
    DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"}
    # Plots produced during post-processing unless overridden via `plot_list`.
    DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                         "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
                         "PlotAvailability"]
    DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"]  # hostname prefixes of HPC login nodes: ju(wels), hdfmll(ogin)
    DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"]  # hostname prefixes of HPC compute nodes: jw (Juwels compute), hdfmlc(ompute)
    
    
    class ExperimentSetup(RunEnvironment):
        """
        Collect all experiment parameters (user supplied or defaults) and write
        them into the shared data store, grouped by scope ("general", "train",
        "val", "test", "train_val", "preprocessing", "general.postprocessing"),
        so that subsequent run modules can read them.

        params:
        trainable: Train new model if true, otherwise try to load existing model
        """

        def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None,
                     statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None,
                     window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                     limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
                     test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                     experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
                     create_new_model=None, bootstrap_path=None, permute_data_on_training=False, transformation=None,
                     train_min_length=None, val_min_length=None, test_min_length=None, extreme_values=None,
                     extremes_on_right_tail_only=None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None,
                     create_new_bootstraps=None, data_path=None, login_nodes=None, hpc_hosts=None):
            """
            Store every given (or defaulted) parameter in the data store.

            Arguments left as ``None`` fall back to the defaults visible in the
            ``_set_param`` calls below. The call order matters: several values
            are derived from entries written earlier (e.g. the "train_val" scope
            is built from the "train" and "val" scopes).
            """

            # create run framework
            super().__init__()

            # experiment setup
            self._set_param("data_path", data_path, default=helpers.prepare_host(sampling=sampling))
            self._set_param("hostname", helpers.get_host())
            # self._set_param("hostname", "jwc0123")
            self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST)
            self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST)
            self._set_param("create_new_model", create_new_model, default=True)
            if self.data_store.get("create_new_model"):
                # a freshly created model must also be trained, regardless of the given flag
                trainable = True
            data_path = self.data_store.get("data_path")
            bootstrap_path = helpers.set_bootstrap_path(bootstrap_path, data_path, sampling)
            self._set_param("bootstrap_path", bootstrap_path)
            self._set_param("trainable", trainable, default=True)
            self._set_param("fraction_of_training", fraction_of_train, default=0.8)
            self._set_param("extreme_values", extreme_values, default=None, scope="train")
            self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="train")
            # upsampling is active exactly when extreme value augmentation was requested
            self._set_param("upsampling", extreme_values is not None, scope="train")
            upsampling = self.data_store.get("upsampling", "train")
            # max() over two booleans acts as logical OR: permuting is forced when upsampling is on
            self._set_param("permute_data", max([permute_data_on_training, upsampling]), scope="train")

            # set experiment name (experiment date may come from the command line parser)
            exp_date = self._get_parser_args(parser_args).get("experiment_date")
            exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date, experiment_path=experiment_path,
                                                             sampling=sampling)
            self._set_param("experiment_name", exp_name)
            self._set_param("experiment_path", exp_path)
            helpers.check_path_and_create(self.data_store.get("experiment_path"))

            # set plot path
            default_plot_path = os.path.join(exp_path, "plots")
            self._set_param("plot_path", plot_path, default=default_plot_path)
            helpers.check_path_and_create(self.data_store.get("plot_path"))

            # set results path
            default_forecast_path = os.path.join(exp_path, "forecasts")
            # NOTE(review): the default is passed positionally here (third positional
            # parameter of _set_param is `default`) -- equivalent to default=default_forecast_path
            self._set_param("forecast_path", forecast_path, default_forecast_path)
            helpers.check_path_and_create(self.data_store.get("forecast_path"))

            # setup for data
            self._set_param("stations", stations, default=DEFAULT_STATIONS)
            self._set_param("network", network, default="AIRBASE")
            self._set_param("station_type", station_type, default=None)
            self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
            # variables default to all keys of statistics_per_var; every variable must have a statistics entry
            self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
            self._compare_variables_and_statistics()
            self._set_param("start", start, default="1997-01-01")
            self._set_param("end", end, default="2017-12-31")
            self._set_param("window_history_size", window_history_size, default=13)
            self._set_param("overwrite_local_data", overwrite_local_data, default=False, scope="preprocessing")
            self._set_param("sampling", sampling)
            self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
            # reset transformation inside the "preprocessing" scope; presumably it is
            # (re)computed from the data during preprocessing -- TODO confirm
            self._set_param("transformation", None, scope="preprocessing")

            # target
            self._set_param("target_var", target_var, default="o3")
            self._check_target_var()
            self._set_param("target_dim", target_dim, default='variables')
            self._set_param("window_lead_time", window_lead_time, default=3)

            # interpolation
            self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']})
            self._set_param("interpolate_dim", interpolate_dim, default='datetime')
            self._set_param("interpolate_method", interpolate_method, default='linear')
            self._set_param("limit_nan_fill", limit_nan_fill, default=1)

            # train set parameters
            self._set_param("start", train_start, default="1997-01-01", scope="train")
            self._set_param("end", train_end, default="2007-12-31", scope="train")
            self._set_param("min_length", train_min_length, default=90, scope="train")

            # validation set parameters
            self._set_param("start", val_start, default="2008-01-01", scope="val")
            self._set_param("end", val_end, default="2009-12-31", scope="val")
            self._set_param("min_length", val_min_length, default=90, scope="val")

            # test set parameters
            self._set_param("start", test_start, default="2010-01-01", scope="test")
            self._set_param("end", test_end, default="2017-12-31", scope="test")
            self._set_param("min_length", test_min_length, default=90, scope="test")

            # train_val set parameters (derived: spans from start of "train" to end of "val",
            # minimum length is the sum of both subsets' minimum lengths)
            self._set_param("start", self.data_store.get("start", "train"), scope="train_val")
            self._set_param("end", self.data_store.get("end", "val"), scope="train_val")
            train_val_min_length = sum([self.data_store.get("min_length", s) for s in ["train", "val"]])
            self._set_param("min_length", train_val_min_length, default=180, scope="train_val")

            # use all stations on all data sets (train, val, test)
            self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True)

            # set post-processing instructions
            self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing")
            # max() acts as logical OR: new bootstraps are forced whenever the model is trainable
            create_new_bootstraps = max([self.data_store.get("trainable", "general"), create_new_bootstraps or False])
            self._set_param("create_new_bootstraps", create_new_bootstraps, scope="general.postprocessing")
            self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing")
            self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")

        def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
            """
            Store `value` under key `param` in the given data store scope.
            Falls back to `default` when `value` is None; note that a None value
            combined with a None default is stored as None.
            """
            if value is None and default is not None:
                value = default
            self.data_store.set(param, value, scope)
            logging.debug(f"set experiment attribute: {param}({scope})={value}")

        @staticmethod
        def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict:
            """
            Transform args to dict if given as argparse.Namespace; any other
            input type yields an empty dict.
            :param args: either a dictionary or an argument parser instance
            :return: dictionary with all arguments
            """
            if isinstance(args, argparse.Namespace):
                return args.__dict__
            elif isinstance(args, dict):
                return args
            else:
                return {}

        def _compare_variables_and_statistics(self):
            """
            Raise a ValueError if any entry of "variables" has no corresponding
            key in "statistics_per_var" (both read from the data store).
            """
            logging.debug("check if all variables are included in statistics_per_var")
            stat = self.data_store.get("statistics_per_var")
            var = self.data_store.get("variables")
            if not set(var).issubset(stat.keys()):
                missing = set(var).difference(stat.keys())
                raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested "
                                 f"variables are part of statistics_per_var. Please add also information on the missing "
                                 f"statistics for the variables: {missing}")

        def _check_target_var(self):
            """
            Ensure every target variable has an entry in "statistics_per_var",
            and drop statistics keys that belong neither to "variables" nor to
            the target variables.
            :raises ValueError: if a target variable is missing from statistics_per_var
            """
            target_var = helpers.to_list(self.data_store.get("target_var"))
            stat = self.data_store.get("statistics_per_var")
            var = self.data_store.get("variables")
            if not set(target_var).issubset(stat.keys()):
                raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
            unused_vars = set(stat.keys()).difference(set(var).union(target_var))
            if len(unused_vars) > 0:
                logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}")
                stat_new = helpers.dict_pop(stat, list(unused_vars))
                self._set_param("statistics_per_var", stat_new)
    
    
    if __name__ == "__main__":

        # Verbose log format including the call location of each message.
        log_format = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'
        logging.basicConfig(format=log_format, level=logging.DEBUG)

        # Read the optional experiment date from the command line.
        arg_parser = argparse.ArgumentParser()
        arg_parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
                                help="set experiment date as string")
        cli_args = arg_parser.parse_args()

        # Execute the setup stage inside its own run environment.
        with RunEnvironment():
            setup = ExperimentSetup(cli_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])