__author__ = "Felix Kleinert"
__date__ = '2022-08-05'

import argparse
from mlair.workflows import DefaultWorkflow
# from mlair.model_modules.recurrent_networks import RNN as chosen_model
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_PLOT_LIST
from mlair.model_modules.probability_models import ProbTestModel4, MyUnetProb, ProbTestModel2, ProbTestModelMixture
import os
import tensorflow as tf


def load_stations(case=0):
    """Load a predefined station list from a JSON file under ``supplement/``.

    The path is relative, so it is resolved against the current working
    directory.

    :param case: selects the file to read:
        0 -> ``supplement/station_list_north_german_plain_rural.json``,
        1 -> ``supplement/station_list_north_german_plain.json``,
        2 -> ``supplement/German_background_stations.json``.
    :return: the parsed JSON content of the selected file, or ``None`` if the
        file does not exist. A warning is logged in that case.
    :raises KeyError: if ``case`` is not one of the keys listed above.
    """
    import json
    import logging

    cases = {
        0: 'supplement/station_list_north_german_plain_rural.json',
        1: 'supplement/station_list_north_german_plain.json',
        2: 'supplement/German_background_stations.json',
    }
    try:
        filename = cases[case]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what went wrong.
        raise KeyError(f"Unknown station list case {case!r}; expected one of {sorted(cases)}.") from None
    try:
        with open(filename, 'r', encoding='utf-8') as jfile:
            stations = json.load(jfile)
    except FileNotFoundError:
        # Best effort: a missing file is not fatal, but it must not go unnoticed.
        # NOTE(review): presumably DefaultWorkflow falls back to its default
        # stations when given None; confirm.
        logging.warning(f"Station list file {filename!r} not found (cwd={os.getcwd()!r}); returning None.")
        stations = None
    return stations


def main(parser_args):
    """Configure and run an MLAir ``DefaultWorkflow`` with the ``MyUnetProb`` model.

    The workflow uses the stations from ``load_stations(2)`` and the variables
    and statistics listed in ``stats_per_var``, with ``o3`` (ppb) as the target.
    The plot list is ``DEFAULT_PLOT_LIST`` without ``PlotConditionalQuantiles``
    and ``PlotPeriodogram``.

    :param parser_args: parsed command-line namespace; each of its attributes is
        forwarded to ``DefaultWorkflow`` as a keyword argument.
    """
    # Previously this list was computed but never passed on, so the exclusion had no effect.
    plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"])
    # Maps each input/target variable to the statistic requested for it.
    stats_per_var = {
        'o3': 'dma8eu',
        'relhum': 'average_values',
        'temp': 'maximum',
        'u': 'average_values',
        'v': 'average_values',
        'no': 'dma8eu',
        'no2': 'dma8eu',
        'cloudcover': 'average_values',
        'pblheight': 'maximum',
    }
    workflow = DefaultWorkflow(
        stations=load_stations(2),  # other station lists: load_stations(0) / load_stations(1)
        model=MyUnetProb,
        window_lead_time=4,
        window_history_size=6,
        epochs=100,
        batch_size=1024,
        # NOTE(review): create_new_model=True with train_model=False -- confirm MLAir
        # enables training for newly created models, otherwise an untrained model is evaluated.
        train_model=False,
        create_new_model=True,
        network="UBA",
        evaluate_feature_importance=False,
        plot_list=plots,
        competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
        variables=list(stats_per_var),
        statistics_per_var=stats_per_var,
        target_var="o3",
        target_var_unit="ppb",
        **vars(parser_args),
        start_script=__file__,
    )
    workflow.run()


def _build_arg_parser():
    """Create the command-line parser for this run script.

    The only option is the experiment date. It is accepted as
    ``--experiment_date`` or its alias ``--exp_date``, and both store into the
    ``experiment_date`` attribute (default ``"testrun"``).
    """
    parser = argparse.ArgumentParser()
    # '--exp_date' is registered as a real alias. Passing it as metavar only
    # changed the help text and did not make the option usable.
    parser.add_argument('--experiment_date', '--exp_date', type=str, default="testrun",
                        help="set experiment date as string")
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    main(args)