"""
Run an experiment workflow: experiment setup, then the pre-processing, training and post-processing stages, with the
execution time of each part logged.
"""
__author__ = "Lukas Leufen"
__date__ = "2019-11-14"


import logging
from src.helpers import TimeTracking
from src import helpers
import argparse
import time


# Log line layout: timestamp, level and message, followed by the source location (file, function, line) that emitted it.
formatter = (
    "%(asctime)s - %(levelname)s: %(message)s  "
    "[%(filename)s:%(funcName)s:%(lineno)s]"
)
logging.basicConfig(format=formatter, level=logging.INFO)


class run(object):
    """
    Basic run class to measure execution time.

    Either use it as a context manager (``with run(): ...`` / ``with run() as r: ...``) or create an instance and delete
    it after finishing the measurement. The duration is logged exactly once, labelled with the (sub)class name, when the
    measurement ends: on leaving the ``with`` block, or otherwise when the instance is garbage collected.
    """

    def __init__(self):
        # Set before anything that could raise, so the clean-up hooks can always query it.
        self._finished = False
        self.time = TimeTracking()
        logging.info(f"{self.__class__.__name__} started")

    def _stop_and_log(self):
        """Stop the timer and log the elapsed time; subsequent calls are no-ops."""
        if getattr(self, "_finished", False):
            return
        self._finished = True
        self.time.stop()
        logging.info(f"{self.__class__.__name__} finished after {self.time}")

    def __del__(self):
        # Fallback for instances not used via ``with``: report when the object is collected. Skip partially initialised
        # objects (e.g. TimeTracking() raised in __init__) instead of raising AttributeError during collection.
        if hasattr(self, "time"):
            self._stop_and_log()

    def __enter__(self):
        # Return the instance so ``with run() as r:`` binds the tracker instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Report as soon as the block ends rather than relying on garbage collection, which is delayed while the
        # instance is still referenced (e.g. through the ``as`` target). Returning None lets exceptions propagate.
        self._stop_and_log()

    def do_stuff(self):
        # Sleeps for two seconds; presumably a dummy workload for trying out the timing.
        time.sleep(2)


class ExperimentSetup:
    """
    Collect the parameters of one experiment as attributes of this object.

    params:
    trainable: Train new model if true, otherwise try to load existing model
    experiment_date: optional date string forwarded to ``helpers.set_experiment_name``. If omitted, the value of the
        ``--experiment_date`` command line option is used when a module-level ``args`` namespace exists (i.e. when this
        file is run as a script); otherwise None is forwarded.
    """

    def __init__(self, trainable=False, experiment_date=None):
        self.data_path = None
        self.experiment_path = None
        self.experiment_name = None
        self.trainable = None
        self.fraction_of_train = None
        self.use_all_stations_on_all_data_sets = None
        self.setup_experiment(trainable, experiment_date=experiment_date)

    def _set_param(self, param, value):
        """Set attribute ``param`` to ``value`` and log the assignment at debug level."""
        setattr(self, param, value)
        logging.debug(f"set attribute: {param}={value}")

    @staticmethod
    def _resolve_experiment_date(experiment_date=None):
        """
        Return the experiment date as a single value (or None).

        Falls back to ``args.experiment_date`` from this module's globals when no date is given, so callers that rely on
        the command line entry point keep working, without raising NameError when the class is imported elsewhere.
        A one-element list/tuple (what argparse produces with ``nargs=1``) is unwrapped to its element.
        """
        if experiment_date is None:
            cli_args = globals().get("args")
            experiment_date = getattr(cli_args, "experiment_date", None)
        if isinstance(experiment_date, (list, tuple)):
            experiment_date = experiment_date[0] if experiment_date else None
        return experiment_date

    def setup_experiment(self, trainable, experiment_date=None):
        """Populate all experiment parameters; see the class docstring for the meaning of the arguments."""

        # set data path of this experiment
        self._set_param("data_path", helpers.prepare_host())

        # set experiment name
        exp_date = self._resolve_experiment_date(experiment_date)
        exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date)
        self._set_param("experiment_name", exp_name)
        self._set_param("experiment_path", exp_path)
        helpers.check_path_and_create(self.experiment_path)

        # set if model is trainable
        self._set_param("trainable", trainable)

        # set fraction of train
        self._set_param("fraction_of_train", 0.8)

        # use all stations on all data sets (train, val, test)
        self._set_param("use_all_stations_on_all_data_sets", True)


class PreProcessing(run):
    """
    Pre-processing stage of the experiment workflow.

    Currently it only keeps a reference to the shared experiment setup; being a ``run`` subclass, its lifetime is timed
    and logged under the name "PreProcessing".
    """

    def __init__(self, setup):
        super().__init__()  # starts the timer and logs "PreProcessing started"
        # shared experiment configuration (an ExperimentSetup instance in this file's entry point)
        self.setup = setup


class Training(run):
    """
    Training stage of the experiment workflow.

    Currently it only keeps a reference to the shared experiment setup; being a ``run`` subclass, its lifetime is timed
    and logged under the name "Training".
    """

    def __init__(self, setup):
        super().__init__()  # starts the timer and logs "Training started"
        # shared experiment configuration (an ExperimentSetup instance in this file's entry point)
        self.setup = setup


class PostProcessing(run):
    """
    Post-processing stage of the experiment workflow.

    Currently it only keeps a reference to the shared experiment setup; being a ``run`` subclass, its lifetime is timed
    and logged under the name "PostProcessing".
    """

    def __init__(self, setup):
        super().__init__()  # starts the timer and logs "PostProcessing started"
        # shared experiment configuration (an ExperimentSetup instance in this file's entry point)
        self.setup = setup


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    # No ``nargs=1``: that would wrap the value in a one-element list, while a plain option yields the string the help
    # text promises. ``metavar`` only names the value placeholder shown in ``--help``; it is not an alternative flag.
    parser.add_argument('--experiment_date', metavar='EXP_DATE', type=str, default=None,
                        help="set experiment date as string")
    # Kept at module level: ExperimentSetup.setup_experiment reads ``args.experiment_date`` from this module.
    args = parser.parse_args()

    # Time the whole workflow. The stage objects below are not kept, so (in CPython) each one logs "started" and,
    # as soon as it is garbage collected, "finished after ...".
    with run():
        exp_setup = ExperimentSetup(trainable=True)

        PreProcessing(exp_setup)

        Training(exp_setup)

        PostProcessing(exp_setup)