import re

__author__ = 'Lukas Leufen, Felix Kleinert'
__date__ = '2019-10-21'


import datetime as dt
import logging
import math
import os
import socket
import sys
import time

import keras.backend as K
import xarray as xr

from typing import Dict, Callable


def to_list(arg):
    """
    Wrap a value into a list unless it already is one.

    Only real `list` instances are passed through unchanged (same object); anything else, including tuples and None,
    becomes the single element of a new list.
    :param arg: value to wrap
    :return: `arg` itself if it is a list, otherwise `[arg]`
    """
    return arg if isinstance(arg, list) else [arg]


def check_path_and_create(path):
    """
    Create the directory `path` (including missing parent directories) unless it already exists.

    An existing entry at `path` is not an error: os.makedirs raises FileExistsError, which is caught and only reported
    on debug level. Other OS errors (e.g. PermissionError) propagate to the caller.
    :param path: directory to create
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        logging.debug(f"Path already exists: {path}")
    else:
        logging.debug(f"Created path: {path}")


def l_p_loss(power: int):
    """
    Build an L<p> loss function for the given power p.

    L1 (p=1) corresponds to the mean absolute error (MAE), L2 (p=2) to the mean squared error (MSE), and so on.
    The returned function averages |y_pred - y_true| ** power over the last axis using keras backend operations.
    :param power: exponent applied to the absolute error
    :return: loss function with signature loss(y_true, y_pred)
    """
    def loss(y_true, y_pred):
        abs_error = K.abs(y_pred - y_true)
        return K.mean(K.pow(abs_error, power), axis=-1)

    return loss


class TimeTracking(object):
    """
    Track time to measure execution time.

    Time tracking starts on initialisation (unless `start=False`) and ends by calling the stop method. The duration can
    always be shown by printing the time tracking object or by calling duration(). When used as a context manager, the
    timer is started on entering (if it has not been started yet), stopped on exit, and the elapsed time is logged on
    info level.
    """

    def __init__(self, start=True, name="undefined job"):
        # start / end hold time.time() timestamps; None means "not started" / "still running"
        self.start = None
        self.end = None
        self._name = name
        if start:
            self._start()

    def _start(self):
        """(Re)start the timer and clear any previous end timestamp."""
        self.start = time.time()
        self.end = None

    def _end(self):
        """Record the current time as end timestamp."""
        self.end = time.time()

    def _duration(self):
        """
        Elapsed seconds: fixed once stopped, otherwise measured up to now.

        Requires the timer to have been started (self.start must not be None).
        """
        if self.end is not None:
            return self.end - self.start
        return time.time() - self.start

    def __repr__(self):
        # round up to whole seconds so that very short jobs are not shown as 0:00:00
        return f"{dt.timedelta(seconds=math.ceil(self._duration()))} (hh:mm:ss)"

    def run(self):
        """Start (or restart) the timer."""
        self._start()

    def stop(self, get_duration=False):
        """
        Stop the timer.

        :param get_duration: if True, return the measured duration in seconds
        :return: duration in seconds if `get_duration` is True, otherwise None
        :raises AssertionError: if the timer was already stopped
        """
        if self.end is not None:
            msg = f"Time was already stopped {time.time() - self.end}s ago."
            raise AssertionError(msg)
        self._end()
        if get_duration:
            return self.duration()

    def duration(self):
        """Return the elapsed time in seconds."""
        return self._duration()

    def __enter__(self):
        # a tracker created with start=False would otherwise have no start timestamp and __exit__ would fail
        if self.start is None:
            self._start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        logging.info(f"{self._name} finished after {self}")


def prepare_host(create_new=True, sampling="daily"):
    """
    Determine the data path for the current host and make sure it exists.

    The path is selected from a fixed set of known hostnames: two workstations, `linux-aa9b`, hosts whose name
    starts with `jr` or `jw`, and hosts matching the `runner-...-project-2411-concurrent...` pattern (presumably CI
    runners). The user name from os.getlogin() is inserted into the path ("default" if it cannot be determined).
    :param create_new: create the path if it does not exist yet
    :param sampling: sampling name used in the final path component `toar_<sampling>`
    :return: the data path (with trailing slash)
    :raises OSError: if the hostname is unknown
    :raises NotADirectoryError: if the path does not exist and either `create_new` is False or creating it failed
        with a PermissionError
    """
    hostname = socket.gethostname()
    runner_regex = re.compile(r"runner-.*-project-2411-concurrent-?\d+")
    try:
        user = os.getlogin()
    except OSError:
        # os.getlogin() raises e.g. when the process has no controlling terminal; fall back to a fixed name
        user = "default"

    if hostname in ("ZAM144", "zam347"):
        path = f"/home/{user}/Data/toar_{sampling}/"
    elif hostname == "linux-aa9b":
        path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/"
    elif len(hostname) > 2 and hostname.startswith("jr"):
        path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/"
    elif len(hostname) > 2 and hostname.startswith("jw"):
        path = f"/p/home/jusers/{user}/juwels/intelliaq/DATA/toar_{sampling}/"
    elif runner_regex.match(hostname) is not None:
        path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/"
    else:
        raise OSError(f"unknown host '{hostname}'")

    if not os.path.exists(path):
        missing_msg = f"path '{path}' does not exist for host '{hostname}'."
        if not create_new:
            raise NotADirectoryError(missing_msg)
        try:
            check_path_and_create(path)
        except PermissionError as err:
            # keep the original permission problem as the explicit cause in the traceback
            raise NotADirectoryError(missing_msg) from err
    logging.debug(f"set path to: {path}")
    return path


def set_experiment_name(experiment_date=None, experiment_path=None, sampling=None):
    """
    Derive the experiment name and the absolute directory of the experiment.

    The name is "TestExperiment" without a date, otherwise "<experiment_date>_network"; for hourly sampling the suffix
    "_hourly" is appended. Without an explicit experiment_path, the experiment directory is placed in the parent
    directory of this module's directory.
    :param experiment_date: optional date/label prefix of the experiment name
    :param experiment_path: optional base directory for the experiment
    :param sampling: sampling name; only "hourly" changes the name
    :return: tuple (experiment_name, experiment_path)
    """
    name = "TestExperiment" if experiment_date is None else f"{experiment_date}_network"
    if sampling == "hourly":
        name = f"{name}_{sampling}"

    if experiment_path is not None:
        full_path = os.path.join(os.path.abspath(experiment_path), name)
    else:
        full_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", name))
    return name, full_path


def set_bootstrap_path(bootstrap_path, data_path, sampling):
    """
    Resolve the directory used for bootstrap data and make sure it exists.

    :param bootstrap_path: explicit directory; if None, `<data_path>/../bootstrap_<sampling>` is used
    :param data_path: data directory the default location is derived from
    :param sampling: sampling name inserted into the default directory name
    :return: the (possibly derived) bootstrap directory
    """
    if bootstrap_path is not None:
        target = bootstrap_path
    else:
        target = os.path.join(data_path, "..", f"bootstrap_{sampling}")
    check_path_and_create(target)
    return target


class PyTestRegex:
    """
    Assert that a given string meets some expectations.

    Instances compare equal to any string that the compiled pattern matches at its beginning (re.match semantics, not
    a full match), which makes them usable inside `assert value == PyTestRegex(...)`.
    """

    def __init__(self, pattern: str, flags: int = 0):
        # compiled once; reused for every comparison
        self._regex = re.compile(pattern, flags)

    def __eq__(self, actual: str) -> bool:
        return self._regex.match(actual) is not None

    def __repr__(self) -> str:
        # show the raw pattern so failing pytest asserts display it
        return self._regex.pattern


class PyTestAllEqual:
    """
    Test helper that checks whether all xarray objects in a list equal the first one.

    Each element is compared against the first element with xarray.testing.assert_equal, which raises an
    AssertionError on mismatch. In practice is_true() therefore either returns True or raises; an empty list counts
    as all equal.
    """

    def __init__(self, check_list):
        # objects to compare; the first entry serves as reference
        self._list = check_list

    def _check_all_equal(self):
        return all(xr.testing.assert_equal(self._list[0], other) is None for other in self._list)

    def is_true(self):
        """Return True if all elements are equal (raises AssertionError otherwise, see class docstring)."""
        return self._check_all_equal()


def xr_all_equal(check_list):
    """
    Check that all xarray objects in `check_list` equal the first one.

    Every element is passed to xarray.testing.assert_equal together with the first element; that function raises an
    AssertionError on mismatch, so in practice this returns True or raises. An empty list yields True.
    :param check_list: list of xarray objects
    :return: True if all comparisons pass
    """
    outcomes = [xr.testing.assert_equal(check_list[0], other) is None for other in check_list]
    return all(outcomes)


def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
    """
    Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>.
    :param d: dictionary with 2D-xarrays
    :param coordinate_name: name of the new created axis (2D -> 3D)
    :return: combined xarray
    """
    xarray = None
    for k, v in d.items():
        if xarray is None:
            xarray = v
            xarray.coords[coordinate_name] = k
        else:
            tmp_xarray = v
            tmp_xarray.coords[coordinate_name] = k
            xarray = xr.concat([xarray, tmp_xarray], coordinate_name)
    return xarray


def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float:
    """
    Perform given rounding operation on number with the precision of decimals.
    :param number: the number to round
    :param decimals: numbers of decimals of the rounding operations (default 0 -> round to next integer value)
    :param round_type: the actual rounding operation. Can be any callable function like math.ceil, math.floor or python
        built-in round operation.
    :return: rounded number with desired precision
    """
    multiplier = 10. ** decimals
    return round_type(number * multiplier) / multiplier


def list_pop(list_full: list, pop_items):
    """
    Return a copy of `list_full` without the given item(s); the input list is not modified.

    Semantics differ slightly depending on the number of items (kept for backward compatibility): with several items,
    every occurrence of each of them is removed; with a single item, only its first occurrence is removed
    (list.remove semantics). Items that are not present are ignored, and an empty list of items removes nothing.
    :param list_full: list to remove items from
    :param pop_items: single item or list of items to remove
    :return: new list without the popped item(s)
    """
    pop_items = to_list(pop_items)
    if len(pop_items) > 1:
        return [e for e in list_full if e not in pop_items]
    l_pop = list_full.copy()
    # an empty selection removes nothing (indexing pop_items[0] would raise IndexError)
    if pop_items:
        try:
            l_pop.remove(pop_items[0])
        except ValueError:
            pass
    return l_pop


def dict_pop(dict_orig: Dict, pop_keys):
    """
    Return a new dictionary without the given key(s); the input dictionary is not modified.

    :param dict_orig: dictionary to filter
    :param pop_keys: single key or list of keys to drop; keys that are not present are ignored
    :return: new dictionary containing all remaining items in their original order
    """
    excluded = to_list(pop_keys)
    kept = {}
    for key, value in dict_orig.items():
        if key not in excluded:
            kept[key] = value
    return kept


class Logger:
    """
    Basic logger class to unify all logging outputs. Logs are saved in local file and returned to std output. In default
    settings, logging level of file logger is DEBUG, logging level of stream logger is INFO. Class must be imported
    and initialised in starting script, all subscripts should log with logging.info(), debug, ...
    """

    def __init__(self, log_path=None, level_file=logging.DEBUG, level_stream=logging.INFO):
        """
        Configure the root logger with a file handler and add a console (stream) handler.

        :param log_path: directory for the log file; a default directory is used if empty/None
        :param level_file: level passed to logging.basicConfig for the root logger / file output
        :param level_stream: level of the console handler
        """
        # shared format string (not a logging.Formatter object) for file and console output
        self.formatter = '%(asctime)s - %(levelname)s: %(message)s  [%(filename)s:%(funcName)s:%(lineno)s]'

        # set log path
        self.log_file = self.setup_logging_path(log_path)
        # set root logger as file handler; note that logging.basicConfig does nothing if the root logger already has
        # handlers configured, in which case no file output is set up here
        logging.basicConfig(level=level_file,
                            format=self.formatter,
                            filename=self.log_file,
                            filemode='a')
        # add stream handler to the root logger
        logging.getLogger('').addHandler(self.logger_console(level_stream))
        # print logger path
        logging.info(f"File logger: {self.log_file}")

    @staticmethod
    def setup_logging_path(path: str = None):
        """
        Check if given path exists and creates if not. If path is None, use path from main. The logging file is named
        like `logging_<runtime>.log` where runtime=`%Y-%m-%d_%H-%M-%S` of current run.
        :param path: path to logfile directory; if empty/None, `<module dir>/../logging` is used
        :return: path of logfile
        """
        if not path:  # set default path
            path = os.path.join(os.path.dirname(__file__), "..", "logging")
        # exist_ok avoids a race between an existence check and the creation when several processes start at once
        os.makedirs(path, exist_ok=True)
        runtime = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        log_file = os.path.join(path, f'logging_{runtime}.log')
        return log_file

    def logger_console(self, level: int):
        """
        Defines a stream handler which writes messages of given level or higher to std out
        :param level: logging level as integer, e.g. logging.DEBUG or 10
        :return: defines stream handler
        """
        # define Handler
        console = logging.StreamHandler()
        # set level of Handler
        console.setLevel(level)
        # use the same format string as the file output
        formatter = logging.Formatter(self.formatter)
        # tell the handler to use this format
        console.setFormatter(formatter)
        return console