# Copyright (c) 2019 Forschungszentrum Juelich GmbH.
# This code is licensed under MIT license (see the LICENSE file for details).

"""
    This program distributes the partitioned MNIST data across multiple ranks
    for truly data distributed training of a shallow ANN for handwritten digit
    classification.

    The Horovod framework is used for seamless distributed training. However,
    instead of distributing epochs, this program distributes data amongst the
    ranks, so that each rank contributes training based on its local subset of
    the training data.

"""

import os
import sys

import mpi4py
import numpy as np
import tensorflow as tf
import horovod.tensorflow.keras as hvd
from tensorflow.keras import backend as K

from hpc4ns.errors import MpiInitError
from hpc4ns.distribution import DataDistributor

sys.path.insert(0, '../utils')
from data_utils import DataValidator


def get_filenames(path):
    """
    Returns a list of names of files available on the given path.

    :param path: str. Valid path to an existing directory.

    :return: list. A list of filenames, where each filename is
                   of type str.
    """

    absolute_path = os.path.abspath(os.path.join(path, 'x'))

    return os.listdir(absolute_path)


def get_concatenated_data(path, filenames):
    """
    Loads all files with the given filenames from the given path,
    and concatenates all the loaded tensors into one large
    tensor.
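
    For instance (shapes are illustrative; the actual shapes depend on
    how the data was partitioned), partitions of shapes (100, 28, 28)
    and (50, 28, 28) are concatenated along the first axis into a
    single (150, 28, 28) tensor.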

    :param path: str. Valid path to an existing directory.
    :param filenames: list. A list of filenames, where each filename is
                   of type str.

    :return: np.ndarray. A tensor with all the loaded content.
    """

    arrays = [
        np.load(os.path.join(path, f)) for f in filenames
    ]

    return np.concatenate(arrays)


def load_dataset(path, filenames):
    """
    Loads the input data and the corresponding labels as
    two np.ndarray types, and returns these as a tuple.

    :param path: str. Valid path to an existing directory.
    :param filenames: list. A list of filenames, where each filename is
                   of type str.

    :return: Tuple consisting two np.ndarray types. The value at
             the first tuple index is the input tensor, while the
             other value is the corresponding array of labels.
    """

    x_dir = os.path.abspath(os.path.join(path, 'x'))
    y_dir = os.path.abspath(os.path.join(path, 'y'))

    x = get_concatenated_data(x_dir, filenames)
    y = get_concatenated_data(y_dir, filenames)

    return x, y


def initialize_hvd_and_mpi():
    """
    Configure and initialize Horovod and MPI. Also, make sure there
    are no conflicts between Horovod and mpi4py communicator
    initialization.

    :exception: hpc4ns.errors.MpiInitError is raised in the case
                of initialization failure.
    """

    # Initialize Horovod.
    hvd.init()

    # Bind the local rank to a specific GPU, so that each rank uses
    # a different GPU
    tf_config = tf.ConfigProto()
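    # Allocate GPU memory on demand, instead of reserving it all up front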
    tf_config.gpu_options.allow_growth = True
    tf_config.gpu_options.visible_device_list = str(hvd.local_rank())
    K.set_session(tf.Session(config=tf_config))

    # Verify that MPI multi-threading is supported. Horovod cannot work
    # with mpi4py (or any other MPI library) otherwise.
    # More info on MPI multi-threading:
    # https://www.mcs.anl.gov/research/projects/mpi/mpi-standard/mpi-report-2.0/node163.htm#Node163
    if not hvd.mpi_threads_supported():
        raise MpiInitError(
            'MPI multi-threading is not supported. Horovod cannot work '
            'with mpi4py in this case. Please enable MPI multi-threading '
            'and try again.'
        )

    # Disable automatic MPI initialization on importing mpi4py.MPI,
    # as we are relying on Horovod to take care of the initialization.
    mpi4py.rc.initialize = False

    # Verify that Horovod and mpi4py are using the same number of ranks
    from mpi4py import MPI
    if hvd.size() != MPI.COMM_WORLD.Get_size():
        raise MpiInitError(
            'Mismatch in hvd.size() and MPI.COMM_WORLD size.'
            f' No. of ranks in Horovod: {hvd.size()}.'
            f' No. of ranks in mpi4py: {MPI.COMM_WORLD.Get_size()}'
        )


def main():
    """ Orchestrates the distributed training program. """

    # Configure and initialize Horovod and mpi4py
    initialize_hvd_and_mpi()

    # Flag to indicate whether this is the MPI root
    is_root = hvd.rank() == 0

    # Decorate the get_filenames function so that instead of returning
    # a list of all filenames, it returns a list of the subset of
    # filenames that are to be processed by the local rank.
    dist_decorator = DataDistributor(
        mpi_comm=mpi4py.MPI.COMM_WORLD, shutdown_on_error=True
    )
    get_rank_local_filenames = dist_decorator(get_filenames)
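
    # For example (illustrative only; the exact partitioning scheme is
    # defined by DataDistributor), with 8 files and 4 ranks, each rank
    # would receive the names of 2 of the 8 files.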

    # Data directory paths
    data_sub_dir = 'mnist/partitioned'
    data_dir = DataValidator.validated_data_dir(data_sub_dir)

    # Prepare training data
    train_dir = os.path.join(data_dir, data_sub_dir, 'train')
    train_filenames = get_rank_local_filenames(train_dir)
    x_train, y_train = load_dataset(train_dir, train_filenames)

    # Normalize pixel values to the [0, 1] range
    x_train = x_train / 255.0

    if is_root:
        # Prepare test data
        test_dir = os.path.join(data_dir, data_sub_dir, 'test')
        test_filenames = get_filenames(test_dir)
        x_test, y_test = load_dataset(test_dir, test_filenames)
        x_test = x_test / 255.0
    else:
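        # Non-root ranks never evaluate the model (see below), so they
        # do not need the test set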
        x_test, y_test = None, None

    # Define the model, i.e., the network
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

    # Optimizer
    optimizer = tf.keras.optimizers.Adam()

    # Wrap the optimizer in Horovod's DistributedOptimizer, which averages
    # the gradients across all ranks (via allreduce) before applying them
    optimizer = hvd.DistributedOptimizer(optimizer)

    # Compile the model
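    # (sparse_categorical_crossentropy expects integer class labels,
    # here the digits 0-9)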
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # Fixed number of epochs; each rank runs all the epochs, but only on
    # its local shard of the data
    epochs = 24

    # Training callbacks. BroadcastGlobalVariablesCallback broadcasts the
    # initial variable states from rank 0 to all other ranks, so that
    # every rank starts training from identical weights.
    callbacks = [
        hvd.callbacks.BroadcastGlobalVariablesCallback(0)
    ]

    # Train the model using the rank-local training data. Only the root
    # rank prints progress, to avoid duplicated output from every rank.
    model.fit(
        x=x_train,
        y=y_train,
        batch_size=32,
        epochs=epochs,
        verbose=1 if is_root else 0,
        callbacks=callbacks
    )

    if is_root:
        # Test the model on the test set
        score = model.evaluate(x=x_test, y=y_test, verbose=0)
        print('Test loss:', score[0])
        print('Test accuracy:', score[1])


if __name__ == '__main__':
    main()