"""
Child class used for configuring the training runscript of the workflow.
"""
__author__ = "Michael Langguth"
__date__ = "2021-01-29"

# import modules
import os, glob
import time
import datetime as dt
import getpass
import shlex
import shutil
import subprocess as sp
from model_modules.model_architectures import known_models
from data_preprocess.dataset_options import known_datasets
from config_utils import Config_runscript_base    # import parent class

class Config_Train(Config_runscript_base):
    """
    Child class for configuring the training runscript of the workflow.
    The keyboard interaction (run_training) queries the dataset, the directory of the preprocessed TFrecords,
    the model and an experimental id. It then creates the destination directory and lets the user configure
    the data-split and hyperparameter json-files copied there.
    """
    cls_name = "Config_Train"

    # names of the models that can be trained (keys of known_models())
    list_models = known_models().keys()
    # !!! Important note !!!
    # As long as we don't have runscript templates for all the datasets listed in known_datasets
    # or a generic template runscript, we need the following manual list
    allowed_datasets = ["era5", "moving_mnist"]  # known_datasets().keys

    def __init__(self, venv_name, lhpc):
        """
        Initializes the training-specific attributes on top of the parent class.
        :param venv_name: name of the virtual environment, passed on to Config_runscript_base
        :param lhpc: flag passed on to Config_runscript_base (presumably marks an HPC setup -- confirm there)
        """
        super().__init__(venv_name, lhpc)

        # initialize attributes related to runscript name
        self.long_name_wrk_step = "Training"
        self.rscrpt_tmpl_prefix = "train_model_"
        # initialize additional runscript-specific attributes to be set via keyboard interaction
        self.model = None
        self.destination_dir = None
        # list of variables to be written to runscript
        self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "model", "destination_dir"]
        # copy over method for keyboard interaction
        # (stored as plain function, i.e. the caller has to pass the instance explicitly)
        self.run_config = Config_Train.run_training
    #
    # -----------------------------------------------------------------------------------
    #
    def run_training(self):
        """
        Runs the keyboard interaction for Training.
        Sets the attributes dataset, runscript_template, source_dir, model, exp_id, runscript_target and
        destination_dir. Creates destination_dir and opens the data-split and hyperparameter json-files
        copied there in $EDITOR.
        :return: None (all attributes of class training are set)
        :raises ValueError: if no component of the tfrecords-path contains the dataset name
        :raises IsADirectoryError: if the destination directory already exists
        :raises FileNotFoundError: if the default data-split or hyperparameter json-file is missing
        """
        method_name = Config_Train.run_training.__name__

        # decide which dataset is used
        dset_type_req_str = "Enter the name of the dataset on which you want to train:"
        dset_err = ValueError("Please select a dataset from the ones listed above.")

        self.dataset = Config_Train.keyboard_interaction(dset_type_req_str, Config_Train.check_dataset,
                                                         dset_err, ntries=2)

        # full path to the runscript template of the chosen dataset
        self.runscript_template = os.path.join(self.runscript_dir,
                                               "{0}{1}{2}".format(self.rscrpt_tmpl_prefix, self.dataset,
                                                                  self.suffix_template))
        # get source dir (relative to base_dir_source!)
        source_dir_base = Config_Train.handle_source_dir(self, "preprocessedData")

        expdir_req_str = "Choose a subdirectory listed above where the preprocessed TFrecords are located:"
        expdir_err = FileNotFoundError("Could not find any tfrecords.")

        self.source_dir = Config_Train.keyboard_interaction(expdir_req_str, Config_Train.check_expdir,
                                                            expdir_err, ntries=3, suffix2arg=source_dir_base+"/")
        # expand source_dir by tfrecords-subdirectory
        self.source_dir = os.path.join(self.source_dir, "tfrecords")

        # split up directory path in order to retrieve exp_dir used for setting up the destination directory:
        # exp_dir is the first element returned by path_rec_split that contains the dataset name
        exp_dir_split = Config_Train.path_rec_split(self.source_dir)
        exp_dir = next((s for s in exp_dir_split if self.dataset in s), None)
        if exp_dir is None:
            raise ValueError(
                    "%{0}: tfrecords found under '{1}', but directory does not seem to reflect naming convention.".format(
                    method_name, self.source_dir))

        # get the model to train
        model_req_str = "Enter the name of the model you want to train:"
        model_err     = ValueError("Please select a model from the ones listed above.")

        self.model = Config_Train.keyboard_interaction(model_req_str, Config_Train.check_model, model_err, ntries=2)

        # experimental ID
        # No need to call keyboard_interaction here, because the user can pass whatever they want.
        # Surrounding whitespace is stripped since the id becomes part of file and directory names.
        self.exp_id = input("*** Enter your desired experimental id (will be extended by timestamp and username):\n").strip()

        # also get current timestamp and user-name...
        timestamp = dt.datetime.now().strftime("%Y%m%dT%H%M%S")
        # USER is not necessarily set in every environment -> fall back to getpass
        user_name = os.environ.get("USER") or getpass.getuser()
        # ... to construct final destination_dir and exp_dir_ext as well
        self.exp_id = timestamp + "_" + user_name + "_" + self.exp_id  # by convention, exp_id is extended by timestamp and username

        # now, we are also ready to set the correct name of the runscript target
        self.runscript_target = "{0}{1}_{2}.sh".format(self.rscrpt_tmpl_prefix, self.dataset, self.exp_id)

        # runscript_template already contains runscript_dir, so it must not be joined with it again
        base_dir = Config_Train.get_var_from_runscript(self.runscript_template, "destination_dir")
        exp_dir_ext = os.path.join(exp_dir, self.model, self.exp_id)
        self.destination_dir = os.path.join(base_dir, "models", exp_dir_ext)

        # sanity check (target_dir is unique):
        if os.path.isdir(self.destination_dir):
            raise IsADirectoryError("%{0}: {1} already exists! Make sure that it is unique.".format(method_name, self.destination_dir))

        # default json-files (paths relative to the current working directory).
        # Check that both exist *before* creating anything, so that a missing template
        # does not leave behind a half-populated destination directory.
        source_datasplit = os.path.join("..", "data_split", "datasplit_template.json")
        source_hparams = os.path.join("..", "hparams", self.dataset, self.model, "model_hparams.json")
        if not os.path.isfile(source_datasplit):
            raise FileNotFoundError("%{0}: Could not find default data_split json-file '{1}'".format(method_name, source_datasplit))
        if not os.path.isfile(source_hparams):
            raise FileNotFoundError("%{0}: Could not find default hyperparameter json-file '{1}'".format(method_name, source_hparams))

        # create destination directory...
        os.makedirs(self.destination_dir)

        # ...copy over json-file for data splitting...
        dest_datasplit = os.path.join(self.destination_dir, "data_split.json")
        shutil.copy(source_datasplit, dest_datasplit)
        # ...and open the editor after some delay
        print("*** Please configure the data splitting:")
        time.sleep(3)
        Config_Train._open_in_editor(dest_datasplit)
        # remove lines starting with '#' since JSON does not support comments
        Config_Train._strip_comment_lines(dest_datasplit)

        # ...copy over json-file for hyperparameters...
        dest_hparams = os.path.join(self.destination_dir, "model_hparams.json")
        shutil.copy(source_hparams, dest_hparams)
        # ...and open the editor after some delay
        print("*** Please configure the model hyperparameters:")
        time.sleep(3)
        Config_Train._open_in_editor(dest_hparams)
    #
    # -----------------------------------------------------------------------------------
    #
    @staticmethod
    def _open_in_editor(file_path):
        """
        Opens file_path in the editor given by the environment variable EDITOR (default: vi)
        and waits for the editor command to return.
        :param file_path: path of the file to edit
        """
        # EDITOR may carry options (e.g. "code -w"), so it is handed to the shell as-is;
        # only the file path is quoted to be robust against blanks/special characters
        cmd_editor = "{0} {1}".format(os.environ.get("EDITOR", "vi"), shlex.quote(file_path))
        sp.call(cmd_editor, shell=True)
    #
    # -----------------------------------------------------------------------------------
    #
    @staticmethod
    def _strip_comment_lines(file_path):
        """
        Removes all lines starting with '#' from file_path in place (equivalent of sed -i '/^#/d').
        :param file_path: path of the (text) file to process
        """
        # newline="" keeps the original line endings untouched
        with open(file_path, "r", encoding="utf-8", newline="") as fobj:
            lines = fobj.readlines()
        with open(file_path, "w", encoding="utf-8", newline="") as fobj:
            fobj.writelines(line for line in lines if not line.startswith("#"))
    #
    # -----------------------------------------------------------------------------------
    #
    @staticmethod
    def check_dataset(dataset_name, silent=False):
        """
        Check if the passed dataset name is known, i.e. contained in allowed_datasets
        :param dataset_name: dataset from keyboard interaction
        :param silent: flag if print-statement are executed
        :return: status with True confirming success
        """
        # NOTE: Templates are only available for ERA5 and moving_MNIST.
        #       After adding further templates or making the template generic,
        #       the latter part of the if-clause can be removed
        #       and further adaptions are required in the configuration chain
        if dataset_name in Config_Train.allowed_datasets:
            return True

        if not silent:
            print("The following dataset can be used for training:")
            for dataset_avail in Config_Train.allowed_datasets:
                print("* " + dataset_avail)
        return False
    #
    # -----------------------------------------------------------------------------------
    #
    @staticmethod
    def check_expdir(exp_dir, silent=False):
        """
        Check if the passed directory path contains TFrecord-files. Note, that the path is extended by tfrecords/
        (e.g. see <base_dir>/model_modules/video_prediction/datasets/era5_dataset.py)
        :param exp_dir: path from keyboard interaction
        :param silent: flag if print-statement are executed
        :return: status with True confirming success
        """
        real_dir = os.path.join(exp_dir, "tfrecords")
        if not os.path.isdir(real_dir):
            if not silent:
                print("Passed directory does not exist!")
            return False

        if glob.glob(os.path.join(real_dir, "sequence*.tfrecords")):
            return True

        if not silent:
            print("{0} does not contain any tfrecord-files.".format(real_dir))
        return False
    #
    # -----------------------------------------------------------------------------------
    #
    @staticmethod
    def check_model(model_name, silent=False):
        """
        Check if the passed model name is known/available, i.e. contained in list_models.
        :param model_name: model name from keyboard interaction
        :param silent: flag if print-statement are executed
        :return: status with True confirming success
        """
        if model_name in Config_Train.list_models:
            return True

        if not silent:
            print("{0} is not a valid model!".format(model_name))
            print("The following models are implemented in the workflow:")
            for model_avail in Config_Train.list_models:
                print("* " + model_avail)
        return False
#
# -----------------------------------------------------------------------------------
#