__author__ = "Lukas Leufen"
__date__ = '2020-06-29'

import argparse
# from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
from mlair.data_handler.data_handler_wrf_chem import DataHandlerWRF, DataHandlerMainSectWRF, DataHandlerMainMinorSectWRF
from mlair.workflows import DefaultWorkflow
from mlair.helpers import remove_items
from mlair.configuration.defaults import DEFAULT_PLOT_LIST

from mlair.model_modules.model_class import IntelliO3TsArchitecture, MyLSTMModel, MyCNNModel, MyCNNModelSect, MyLuongAttentionLSTMModel, MyUnet

import os


def load_stations(filename='supplement/WRF_coord_list_from_IntelliO3.json'):
    """Load the station list for this experiment from a JSON file.

    :param filename: path of the JSON file to read. A relative path is resolved
        against the current working directory, not against this script's folder.
        Defaults to the IntelliO3 coordinate list in ``supplement/``.
    :return: the decoded JSON content, or ``None`` when ``filename`` does not exist.
        What ``DefaultWorkflow`` does with ``stations=None`` is not visible here.
    :raises json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    import json
    try:
        # Only opening the file can raise FileNotFoundError; keep the try that narrow.
        # JSON text is UTF-8 (RFC 8259), so do not depend on the platform locale.
        jfile = open(filename, 'r', encoding='utf-8')
    except FileNotFoundError:
        return None
    with jfile:
        return json.load(jfile)


def main(parser_args):
    """Configure and run the WRF-Chem ozone experiment through MLAir's ``DefaultWorkflow``.

    All experiment settings are fixed in this function. Every attribute of
    ``parser_args`` is forwarded unchanged as an extra keyword argument to
    ``DefaultWorkflow``.

    :param parser_args: namespace produced by the command line parser.
    """
    # Remove the plots not wanted for this experiment from MLAir's default selection.
    excluded_plots = ["PlotDataHistogram", "PlotAvailability"]
    selected_plots = remove_items(DEFAULT_PLOT_LIST, excluded_plots)

    # Model inputs; 'o3' is also the prediction target (see target_var below).
    # (Earlier variant: 'CLDFRA' instead of 'Q2'.)
    input_variables = ['T2', 'o3', 'wdir10ll', 'wspd10ll', 'no', 'no2', 'co', 'PSFC', 'PBLH', 'Q2']

    # Chemical species use "dma8eu"; every other input uses "average_values".
    dma8eu_variables = {'o3', 'no', 'no2', 'co'}
    statistics = {
        name: ("dma8eu" if name in dma8eu_variables else "average_values")
        for name in input_variables
    }

    # Standardise everything listed here, then override the wind direction with a
    # min-max scaling over [0, 360]. Re-assigning an existing key keeps its position.
    # NOTE(review): 'Ull'/'Vll' have entries although they are not inputs, and 'o3'
    # has none -- confirm this is intended.
    transformation = {
        name: {"method": "standardise"}
        for name in ("T2", "Q2", "PBLH", "Ull", "Vll", "wdir10ll",
                     "wspd10ll", "no", "no2", "co", "PSFC")
    }
    transformation["wdir10ll"] = {"method": "min_max", "min": 0., "max": 360.}

    workflow = DefaultWorkflow(
        stations=load_stations(),
        # --- workflow control ---
        lazy_preprocessing=False,
        # NOTE(review): train_model=False combined with create_new_model=True --
        # confirm which of the two MLAir honours.
        train_model=False,
        create_new_model=True,
        network="UBA",
        # --- feature importance (disabled) ---
        evaluate_feature_importance=False,
        feature_importance_bootstrap_type="group_of_variables",
        feature_importance_create_new_bootstraps=False,
        feature_importance_bootstrap_method="zero_mean",
        plot_list=selected_plots,
        uncertainty_estimate_block_length="7d",
        train_min_length=1,
        val_min_length=1,
        test_min_length=1,
        epochs=300,
        window_lead_time=4,
        window_history_size=6,
        # --- WRF-Chem input data ---
        data_handler=DataHandlerWRF,
        data_path="/p/scratch/deepacf/intelliaq/kleinert1/IASS_proc_monthly/monthly2009_2010-03",
        common_file_starter="wrfout_d01",
        date_format_of_nc_file="%Y-%m",
        time_dim='XTIME',
        external_coords_file="/p/scratch/deepacf/intelliaq/kleinert1/IASS_proc_monthly/coords.nc",
        transformation=transformation,
        variables=input_variables,
        target_var='o3',
        target_var_unit="ppb",
        vars_for_unit_conv={'o3': 'ppbv'},
        statistics_per_var=statistics,
        wind_sectors=['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'],
        var_logical_z_coord_selector=0,
        targetvar_logical_z_coord_selector=0,
        aggregation_dim='bottom_top',
        radius=200,  # km
        # --- periods: train, validation and test are consecutive and do not overlap ---
        # (Short development variant: start/end 2009-01-01..2009-01-31 with
        #  train ..01-15, val 01-15..01-22, test 01-22..01-31.)
        start='2009-01-01',
        end='2010-03-31',
        train_start='2009-01-01',
        train_end='2009-10-15',
        val_start='2009-10-16',
        val_end='2009-12-14',
        test_start='2009-12-15',
        test_end='2010-03-31',
        # --- temporal resolution ---
        sampling="daily",
        input_output_sampling4toarstats=("hourly", "daily"),
        time_zone="UTC",
        target_time_type="solar_time",
        use_multiprocessing=True,
        # --- batching / model ---
        batch_size=256,  # 64 * 2 * 2
        interpolation_limit=0,
        as_image_like_data_format=True,
        model=MyUnet,
        **parser_args.__dict__
    )
    workflow.run()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun",
                        help="set experiment date as string")
    args = parser.parse_args()
    main(args)