import pytest
import os
import keras
import math
import numpy as np

from src.modules.training import Distributor
from src.data_generator import DataGenerator
from src.inception_model import InceptionModelBase
from src.flatten import flatten_tail


def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
    """Build a small Keras model made of a single inception block for the Distributor tests.

    :param activation: activation passed to every inception tower, the pooling tower and the flatten tails
    :param window_history_size: history length; the input has ``window_history_size + 1`` time steps
    :param channels: number of input channels (last input dimension)
    :param dropout_rate: rate of the dropout layer placed in front of the 'Main' output
    :param add_minor_branch: if True, an additional 'Minor_1' output is taken before the dropout layer and is
        listed first among the outputs
    :return: ``keras.Model`` with one output ('Main') or two outputs ('Minor_1', 'Main')
    """
    block_builder = InceptionModelBase()
    n_tower_filters = 8 * 2
    tower_settings = {
        'tower_1': {'reduction_filter': 8, 'tower_filter': n_tower_filters, 'tower_kernel': (3, 1),
                    'activation': activation},
        'tower_2': {'reduction_filter': 8, 'tower_filter': n_tower_filters, 'tower_kernel': (5, 1),
                    'activation': activation},
    }
    pooling_settings = {'pool_kernel': (3, 1), 'tower_filter': n_tower_filters, 'activation': activation}

    model_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
    features = block_builder.inception_block(model_input, tower_settings, pooling_settings)

    # The minor branch (if any) is built before the dropout layer so it sees the raw inception features.
    outputs = [flatten_tail(features, 'Minor_1', activation=activation)] if add_minor_branch else []
    dropped = keras.layers.Dropout(dropout_rate)(features)
    outputs.append(flatten_tail(dropped, 'Main', activation=activation))
    return keras.Model(inputs=model_input, outputs=outputs)


class TestDistributor:
    """Tests for ``Distributor``, which splits the data of a ``DataGenerator`` into mini batches for a model."""

    @pytest.fixture
    def generator(self):
        """DataGenerator over the local test ``data`` directory for the single station DEBW107 (o3, temp)."""
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def generator_two_stations(self):
        """DataGenerator over the local test ``data`` directory for the stations DEBW107 and DEBW013."""
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
                             ['o3', 'temp'], 'datetime', 'variables', 'o3',
                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def model(self):
        """Single-output test model (PReLU, window history 5, 3 channels, dropout 0.1, no minor branch)."""
        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)

    @pytest.fixture
    def distributor(self, generator, model):
        """Distributor for the single-station generator, built with its default settings."""
        return Distributor(generator, model)

    def test_init_defaults(self, distributor):
        assert distributor.batch_size == 256
        assert distributor.fit_call is True

    def test_get_model_rank(self, distributor):
        # The rank equals the number of model outputs: 1 without, 2 with the minor branch.
        assert distributor._get_model_rank() == 1
        distributor.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
        assert distributor._get_model_rank() == 2

    def test_get_number_of_mini_batches(self, distributor):
        values = np.zeros((2, 2311, 19))
        assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)

    def test_distribute_on_batches(self, generator_two_stations, model):
        d = Distributor(generator_two_stations, model)
        # Finite pass (fit_call=False): no batch may exceed the configured batch size.
        for e in d.distribute_on_batches(fit_call=False):
            assert e[0].shape[0] <= d.batch_size
        # Default call loops forever: store the first pass, check that the second pass repeats it batch by batch,
        # and stop as soon as the third pass starts.
        n_batches = len(d)
        elements = []
        n_checked = 0
        for i, e in enumerate(d.distribute_on_batches()):
            if i < n_batches:
                elements.append(e[0])
            elif i < 2 * n_batches:  # second pass must repeat the first one element-wise
                np.testing.assert_array_equal(e[0], elements[i - n_batches])
            else:  # third pass starts (generator is an infinite loop), stop here
                break
            n_checked += 1
        # Guard against a generator that stops early, which would otherwise silently skip the repetition check.
        assert n_checked == 2 * n_batches

    def test_len(self, distributor):
        # generator[0][0] presumably holds the inputs of the (only) station; len() counts its samples.
        expected = math.ceil(len(distributor.generator[0][0]) / distributor.batch_size)
        assert len(distributor) == expected

    def test_len_two_stations(self, generator_two_stations, model):
        gen = generator_two_stations
        d = Distributor(gen, model)
        # Mini batches are counted per station and then summed up.
        expected = math.ceil(len(gen[0][0]) / d.batch_size) + math.ceil(len(gen[1][0]) / d.batch_size)
        assert len(d) == expected