Skip to content
Snippets Groups Projects
Select Git revision
  • 3b924699042742c86b1d1af5ca90d35ed2ab6ccb
  • master default
  • bing_issues#190_tf2
  • bing_tf2_convert
  • bing_issue#189_train_modular
  • simon_#172_integrate_weatherbench
  • develop
  • bing_issue#188_restructure_ambs
  • yan_issue#100_extract_prcp_data
  • bing_issue#170_data_preprocess_training_tf1
  • Gong2022_temperature_forecasts
  • bing_issue#186_clean_GMD1_tag
  • yan_issue#179_integrate_GZAWS_data_onfly
  • bing_issue#178_runscript_bug_postprocess
  • michael_issue#187_bugfix_setup_runscript_template
  • bing_issue#180_bugs_postprpocess_meta_postprocess
  • yan_issue#177_repo_for_CLGAN_gmd
  • bing_issue#176_integrate_weather_bench
  • michael_issue#181_eval_era5_forecasts
  • michael_issue#182_eval_subdomain
  • michael_issue#119_warmup_Horovod
  • bing_issue#160_test_zam347
  • ambs_v1
  • ambs_gmd_nowcasting_v1.0
  • GMD1
  • modular_booster_20210203
  • new_structure_20201004_v1.0
  • old_structure_20200930
28 results

main_train_models.py

Blame
  • run.py 1.53 KiB
    __author__ = "Lukas Leufen"
    __date__ = '2020-06-29'
    
    import argparse
    from mlair.workflows import DefaultWorkflow
    # from mlair.model_modules.recurrent_networks import RNN as chosen_model
    from mlair.helpers import remove_items
    from mlair.configuration.defaults import DEFAULT_PLOT_LIST
    import os
    import tensorflow as tf
    
    
    def load_stations():
        """Load the station list from the JSON supplement file.

        The file ``supplement/station_list_north_german_plain_rural.json`` is
        resolved relative to the current working directory.

        :return: the decoded JSON content, or ``None`` if the file does not exist
        """
        import json

        filename = 'supplement/station_list_north_german_plain_rural.json'
        try:
            # JSON text is UTF-8; state it explicitly instead of relying on the
            # platform's locale-dependent default encoding.
            with open(filename, 'r', encoding='utf-8') as jfile:
                return json.load(jfile)
        except FileNotFoundError:
            # A missing file is not treated as an error: the caller receives None.
            return None
    
    
    def main(parser_args):
        """Build a ``DefaultWorkflow`` with a fixed test configuration and run it.

        :param parser_args: parsed command line namespace; every attribute it holds
            is forwarded to ``DefaultWorkflow`` as a keyword argument
        """
        # tf.compat.v1.disable_v2_behavior()
        # NOTE(review): the reduced plot list below is computed but never handed to the
        # workflow (no plot_list argument is passed) -- confirm whether this is intended.
        plots = remove_items(DEFAULT_PLOT_LIST, ["PlotConditionalQuantiles", "PlotPeriodogram"])

        # Alternative station sources used before: load_stations(), or the same four
        # stations in the order ["DEBW087", "DEBW013", "DEBW107", "DEBW076"].
        station_ids = ["DEBW013", "DEBW087", "DEBW107", "DEBW076"]
        competitor_dir = os.path.join(os.getcwd(), "data", "comp_test")

        workflow = DefaultWorkflow(
            stations=station_ids,
            train_model=False,
            create_new_model=True,
            network="UBA",
            evaluate_feature_importance=False,  # plot_list=["PlotCompetitiveSkillScore"],
            competitors=["test_model", "test_model2"],
            competitor_path=competitor_dir,
            **vars(parser_args),
            start_script=__file__,
        )
        workflow.run()
    
    
    if __name__ == "__main__":
        # Script entry point: read the experiment date from the command line
        # (falls back to "testrun") and start the workflow.
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--experiment_date',
            metavar='--exp_date',
            type=str,
            default="testrun",
            help="set experiment date as string",
        )
        main(parser.parse_args())