__author__ = "Lukas Leufen"
__date__ = '2020-06-29'

import argparse
from mlair.workflows import DefaultWorkflow
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_PLOT_LIST
from mlair.model_modules.model_class import IntelliO3_ts_architecture, IntelliO3_ts_architecture_finetune_all_dense, \
    IntelliO3_ts_architecture_finetune_outputs, IntelliO3_ts_architecture_finetune_main_output
import os


def load_stations(filename='supplement/station_list_north_german_plain_rural.json'):
    """Load a station list from a JSON file.

    :param filename: path of the JSON file to read. Relative paths are resolved against the current working
        directory. The default is the rural North German Plain station list that this script used originally.
    :return: the decoded JSON content, or ``None`` if the file does not exist.
    """
    import json
    try:
        # JSON text is UTF-8 by specification; do not depend on the platform's locale encoding.
        with open(filename, 'r', encoding='utf-8') as jfile:
            return json.load(jfile)
    except FileNotFoundError:
        # A missing file is tolerated deliberately: returning None leaves the choice of stations to the caller
        # (presumably DefaultWorkflow then uses its own default stations — confirm).
        return None


def main(parser_args):
    """Configure and run an MLAir ``DefaultWorkflow``.

    Uses ``IntelliO3_ts_architecture_finetune_all_dense`` on four hard-coded stations and loads
    ``external_weights``. It trains for one epoch with a lead time of one, plots only ``PlotContingency``, and
    compares against the ``withoutfinetuning`` competitor read from ``<cwd>/data/competitors/o3``.

    :param parser_args: parsed command line arguments (an ``argparse.Namespace``). Every attribute is forwarded
        to ``DefaultWorkflow`` as a keyword argument.
    """
    workflow = DefaultWorkflow(
        # To read the stations from the supplement JSON file instead, pass ``stations=load_stations()``.
        stations=["DEBW087", "DEBW013", "DEBW107", "DEBW076"],
        epochs=1,
        # NOTE(review): machine-specific absolute path; this only resolves on the original author's machine.
        external_weights="/home/vincentgramlich/mlair/data/weights/testrun_network_daily_model-best.h5",
        train_model=True,
        create_new_model=True,
        network="UBA",
        model=IntelliO3_ts_architecture_finetune_all_dense,
        window_lead_time=1,
        evaluate_bootstraps=False,
        plot_list=["PlotContingency"],
        competitors=["withoutfinetuning"],
        competitor_path=os.path.join(os.getcwd(), "data", "competitors", "o3"),
        **vars(parser_args),
        start_script=__file__)
    workflow.run()


if __name__ == "__main__":
    # The experiment date is the only command line option. argparse stores it as ``experiment_date``, and
    # main() forwards it to DefaultWorkflow together with its fixed configuration.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--experiment_date',
        metavar='--exp_date',
        type=str,
        default="testrun",
        help="set experiment date as string",
    )
    args = parser.parse_args()
    main(args)