import keras

from src.inception_model import InceptionModelBase
from src.flatten import flatten_tail


def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
    """Build a small test network around a single inception block.

    The input tensor has shape ``(window_history_size + 1, 1, channels)``.
    It is passed through one ``InceptionModelBase.inception_block`` that is
    configured with:

    * two convolution towers, ``tower_1`` with kernel ``(3, 1)`` and
      ``tower_2`` with kernel ``(5, 1)``. Each tower uses
      ``reduction_filter=8`` and ``tower_filter=16``.
    * a pooling tower with ``pool_kernel=(3, 1)`` and ``tower_filter=16``.

    The output of the block feeds the model heads, each built by
    ``flatten_tail``:

    * ``'Minor_1'`` (only if ``add_minor_branch``) is attached directly to
      the inception output, before dropout.
    * ``'Main'`` is attached after a ``keras.layers.Dropout`` layer.

    Args:
        activation: Forwarded unchanged to every tower and to each
            ``flatten_tail`` call. Presumably a Keras activation name or
            layer; confirm against ``InceptionModelBase``.
        window_history_size: The first input dimension is this value plus one.
        channels: Size of the last input dimension.
        dropout_rate: Rate passed to ``keras.layers.Dropout`` before the main
            head.
        add_minor_branch: If true, an extra ``'Minor_1'`` output is placed in
            front of the main output.

    Returns:
        A ``keras.Model`` whose ``outputs`` is a list. The list is
        ``[Minor_1, Main]`` when ``add_minor_branch`` is true and ``[Main]``
        otherwise.
    """
    base = InceptionModelBase()
    n_filters = 8

    # Keys stay 'tower_1', 'tower_2' in that order, and each inner dict keeps
    # the original key order, so inception_block sees identical settings.
    conv_towers = {
        'tower_{}'.format(idx): {
            'reduction_filter': n_filters,
            'tower_filter': n_filters * 2,
            'tower_kernel': (kernel_len, 1),
            'activation': activation,
        }
        for idx, kernel_len in enumerate((3, 5), start=1)
    }
    pooling = {'pool_kernel': (3, 1), 'tower_filter': n_filters * 2, 'activation': activation}

    inputs = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
    features = base.inception_block(inputs, conv_towers, pooling)

    heads = []
    if add_minor_branch:
        # The minor head taps the inception features *before* dropout.
        heads.append(flatten_tail(features, 'Minor_1', activation=activation))

    dropped = keras.layers.Dropout(dropout_rate)(features)
    heads.append(flatten_tail(dropped, 'Main', activation=activation))

    return keras.Model(inputs=inputs, outputs=heads)