Patch summary
  * mlair/run_modules/pre_processing.py: the before/after oversampling
    diagnostic plots are disabled by wrapping them in a ''' string literal.
  * run_with_oversampling.py, run_without_oversampling.py: load_stations()
    takes an optional path to a station-list JSON file; both runs now use the
    first 75 entries of supplement/German_background_station.json, set
    train_model=True and window_lead_time=2.

Review notes
  * FIX: the hard-coded "filename = ..." assignment inside the try block of
    load_stations() is now removed in both scripts. It used to overwrite the
    file name chosen from external_station_list, so the requested station
    list was never opened.
  * FIX: both load_stations() signatures now read
    "external_station_list=None", so they agree with each other and a plain
    load_stations() call keeps working.
  * NOTE(review): "if not on HPC:" inside the ''' block is pseudo-code, not
    valid Python. It is harmless only while it stays inside the string
    literal; replace it with a real condition before re-enabling the plots.
  * NOTE(review): leading whitespace and blank context lines were
    reconstructed by convention. Confirm them against the target files before
    applying. The post-image blob ids on the "index" lines of the two run
    scripts predate the fixes above.

diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index e265bd24b4408562653a80ab9f80745246bbfc9c..215c0bb80c05fd9ac267418961c2ca96e025b3e2 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -114,28 +114,29 @@ class PreProcessing(RunEnvironment):
             Y = xr.concat([Y, station._Y], dim="Stations")
             Y_extreme = xr.concat([Y_extreme, station._Y_extreme], dim="Stations")
 
-        fig, ax = plt.subplots(nrows=2, ncols=2)
-        fig.suptitle(f"Window Size=1, Bins={bins}, rates_cap={rates_cap}")
-        Y_hist = Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax[0,0])[0]
-        Y_extreme_hist = Y_extreme.plot.hist(bins=bin_edges, histtype="step", label="After", ax=ax[0,0])[0]
-        ax[0,0].set_title(f"Histogram before-after oversampling")
-        ax[0,0].legend()
-        Y_hist_dens = Y.plot.hist(bins=bin_edges, density=True, histtype="step", label="Before", ax=ax[0,1])[0]
-        Y_extreme_hist_dens = Y_extreme.plot.hist(bins=bin_edges, density=True, histtype="step", label="After", ax=ax[0,1])[0]
-        ax[0,1].set_title(f"Density-Histogram before-after oversampling")
-        ax[0,1].legend()
-        real_oversampling = Y_extreme_hist/Y_hist
-        ax[1,0].plot(range(len(real_oversampling)), oversampling_rates_capped, label="Desired oversampling_rates")
-        ax[1,0].plot(range(len(real_oversampling)), real_oversampling, label="Actual Oversampling Rates")
-        ax[1,0].set_title(f"Oversampling rates")
-        ax[1,0].legend()
-        ax[1,1].plot(range(len(real_oversampling)), real_oversampling / oversampling_rates_capped,
-                     label="Actual/Desired Rate")
-        ax[1,1].set_title(f"Deviation from desired Oversampling rate")
-        ax[1,1].legend()
-        plt.show()
-        #data[1]._Y.where(data[1]._Y > bin_edges[9], drop=True)
-        #data[1]._Y_extreme.where(data[1]._Y_extreme > bin_edges[9], drop=True)
+        '''
+        if not on HPC:
+            fig, ax = plt.subplots(nrows=2, ncols=2)
+            fig.suptitle(f"Window Size=1, Bins={bins}, rates_cap={rates_cap}")
+            Y_hist = Y.plot.hist(bins=bin_edges, histtype="step", label="Before", ax=ax[0,0])[0]
+            Y_extreme_hist = Y_extreme.plot.hist(bins=bin_edges, histtype="step", label="After", ax=ax[0,0])[0]
+            ax[0,0].set_title(f"Histogram before-after oversampling")
+            ax[0,0].legend()
+            Y_hist_dens = Y.plot.hist(bins=bin_edges, density=True, histtype="step", label="Before", ax=ax[0,1])[0]
+            Y_extreme_hist_dens = Y_extreme.plot.hist(bins=bin_edges, density=True, histtype="step", label="After", ax=ax[0,1])[0]
+            ax[0,1].set_title(f"Density-Histogram before-after oversampling")
+            ax[0,1].legend()
+            real_oversampling = Y_extreme_hist/Y_hist
+            ax[1,0].plot(range(len(real_oversampling)), oversampling_rates_capped, label="Desired oversampling_rates")
+            ax[1,0].plot(range(len(real_oversampling)), real_oversampling, label="Actual Oversampling Rates")
+            ax[1,0].set_title(f"Oversampling rates")
+            ax[1,0].legend()
+            ax[1,1].plot(range(len(real_oversampling)), real_oversampling / oversampling_rates_capped,
+                         label="Actual/Desired Rate")
+            ax[1,1].set_title(f"Deviation from desired Oversampling rate")
+            ax[1,1].legend()
+            plt.show()
+        '''
 
     def report_pre_processing(self):
         """Log some metrics on data and create latex report."""
diff --git a/run_with_oversampling.py b/run_with_oversampling.py
index cbab9b4e579d41b975c3892f528a961341985366..b21e5e6f98df00a5f866c10db280f29e3366b014 100644
--- a/run_with_oversampling.py
+++ b/run_with_oversampling.py
@@ -9,8 +9,11 @@ from mlair.model_modules.model_class import IntelliO3_ts_architecture
 import os
 
 
-def load_stations():
+def load_stations(external_station_list=None):
     import json
+    if external_station_list is None:
+        filename = 'supplement/station_list_north_german_plain_rural.json'
+    else:
+        filename = external_station_list
     try:
-        filename = 'supplement/station_list_north_german_plain_rural.json'
         with open(filename, 'r') as jfile:
@@ -22,15 +25,14 @@ def load_stations():
 
 def main(parser_args):
     plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
-    workflow = DefaultWorkflow( # stations=load_stations(),
-                               # stations=["DEBW087","DEBW013", "DEBW107", "DEBW076"],
-                               stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
-                               train_model=False, create_new_model=True, network="UBA",
+    workflow = DefaultWorkflow(stations=load_stations('supplement/German_background_station.json')[:75],
+                               #stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
+                               train_model=True, create_new_model=True, network="UBA",
                                model=IntelliO3_ts_architecture, oversampling_method="bin_oversampling",
                                evaluate_bootstraps=False, # plot_list=["PlotCompetitiveSkillScore"],
                                competitors=["test_model", "test_model2"],
                                competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
-                               window_lead_time=1, oversampling_bins=10, oversampling_rates_cap=100,
+                               window_lead_time=2, oversampling_bins=10, oversampling_rates_cap=100,
                                **parser_args.__dict__)
     workflow.run()
 
diff --git a/run_without_oversampling.py b/run_without_oversampling.py
index 3c69b4508b644a1b2ef3ffc1b18a9ee6796eac02..5b51ffa0d73c0a92491f0cf52285dd53dcec5cd3 100644
--- a/run_without_oversampling.py
+++ b/run_without_oversampling.py
@@ -9,8 +9,11 @@ from mlair.model_modules.model_class import IntelliO3_ts_architecture
 import os
 
 
-def load_stations():
+def load_stations(external_station_list=None):
     import json
+    if external_station_list is None:
+        filename = 'supplement/station_list_north_german_plain_rural.json'
+    else:
+        filename = external_station_list
     try:
-        filename = 'supplement/station_list_north_german_plain_rural.json'
         with open(filename, 'r') as jfile:
@@ -22,15 +25,14 @@ def load_stations():
 
 def main(parser_args):
     plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles")
-    workflow = DefaultWorkflow( # stations=load_stations(),
-                               # stations=["DEBW087","DEBW013", "DEBW107", "DEBW076"],
-                               stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
-                               train_model=False, create_new_model=True, network="UBA",
+    workflow = DefaultWorkflow(stations=load_stations('supplement/German_background_station.json')[:75],
+                               #stations=["DEBW013", "DEBW087", "DEBW107", "DEBW076"],
+                               train_model=True, create_new_model=True, network="UBA",
                                model=IntelliO3_ts_architecture,
                                evaluate_bootstraps=False, # plot_list=["PlotCompetitiveSkillScore"],
                                competitors=["test_model", "test_model2"],
                                competitor_path=os.path.join(os.getcwd(), "data", "comp_test"),
-                               window_lead_time=1, oversampling_bins=10, oversampling_rates_cap=100,
+                               window_lead_time=2, oversampling_bins=10, oversampling_rates_cap=100,
                                **parser_args.__dict__)
     workflow.run()
 