diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index a5c6a8e37a16f6a1e5d9d086c1995caeb3dbeb27..834342689a160ca40b99963de9731f7d1ad47f51 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -364,17 +364,18 @@ class IntelliO3_ts_architecture(AbstractModelClass): assert len(output_shape) == 1 super().__init__(input_shape[0], output_shape[0]) - from mlair.model_modules.keras_extensions import LearningRateDecay - # settings self.dropout_rate = .35 self.regularizer = keras.regularizers.l2(0.01) self.initial_lr = 1e-4 - self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) self.activation = keras.layers.ELU self.padding = "SymPad2D" + self.apply_to_model() # apply to model + def apply_to_model(self): + from mlair.model_modules.keras_extensions import LearningRateDecay + self.lr_decay = LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) self.set_model() self.set_compile_options() self.set_custom_objects(loss=self.compile_options["loss"][0], @@ -467,9 +468,12 @@ class IntelliO3_ts_architecture(AbstractModelClass): class IntelliO3_ts_architecture_freeze(IntelliO3_ts_architecture): def __init__(self, input_shape: list, output_shape: list): super().__init__(input_shape, output_shape) + self.freeze_layers() self.initial_lr = 1e-5 - ''' + self.apply_to_model() + # self.lr_decay = None + def freeze_layers(self): for layer in self.model.layers: if not isinstance(layer, keras.layers.core.Dense): @@ -479,4 +483,5 @@ class IntelliO3_ts_architecture_freeze(IntelliO3_ts_architecture): def freeze_layers(self): for layer in self.model.layers: if layer.name not in ["minor_1_out_Dense", "Main_out_Dense"]: - layer.trainable = False \ No newline at end of file + layer.trainable = False + ''' \ No newline at end of file diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 6eba642b44bfee786529240cd21176b3c8cc039a..2c8fd6e4935e063a8026ebdc6cb1309d4d8f11fa 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -30,7 +30,6 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING) @TimeTrackingWrapper class PlotContingency(AbstractPlotClass): - #Todo: Get min and max_label def __init__(self, station_names, file_path, comp_path, file_name, plot_folder: str = ".", model_name: str = "nn", obs_name: str = "obs", comp_names: str = "IntelliO3", @@ -90,7 +89,7 @@ class PlotContingency(AbstractPlotClass): else: for type in data.type.values.tolist(): plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type=type, scores=score)], label=type) - plt.title(self._plot_names[self._plot_counter][13:]) + plt.title(self._plot_names[self._plot_counter]) plt.legend() self.plot_name = self._plot_names[self._plot_counter] self._plot_counter = self._plot_counter + 1 diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index 83f4a2bd96314d6f8c53f8cc9407cbc12e7b9a16..28a4e99abe83102e632ec98000ae9233362dbe73 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -83,7 +83,10 @@ class ModelSetup(RunEnvironment): self.plot_model() # load weights if no training shall be performed - if not self._train_model and not self._create_new_model: + external_weights = self.data_store.get("external_weights") + if external_weights is not None: + self.load_weights(external_weights) + elif not self._train_model and not self._create_new_model: self.load_weights() # create checkpoint @@ -131,11 +134,13 @@ class ModelSetup(RunEnvironment): save_best_only=True, mode='auto') self.data_store.set("callbacks", callbacks, self.scope) - def load_weights(self): + def load_weights(self, external_weight=None): """Try to load weights from existing model or skip if not possible.""" + if external_weight is None: + external_weight = self.model_name try: - self.model.load_weights(self.model_name) - logging.info(f"reload weights from model {self.model_name} ...") + self.model.load_weights(external_weight) + logging.info(f"reload weights from model {external_weight} ...") except OSError: logging.info('no weights to reload...') diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py index 9cafbbf6b9c8792fddacf88fac3a52261b9b4c5f..00e8eae1581453666d3ca11f48fcdaedf6a24ad0 100644 --- a/mlair/run_modules/training.py +++ b/mlair/run_modules/training.py @@ -171,11 +171,7 @@ class Training(RunEnvironment): except IndexError: epo_timing = None self.save_callbacks_as_json(history, lr, epo_timing) - external_weights = self.data_store.get("external_weights") - if external_weights is not None: - self.load_best_model(external_weights) - else: - self.load_best_model(checkpoint.filepath) + self.load_best_model(checkpoint.filepath) self.create_monitoring_plots(history, lr) def save_model(self) -> None: diff --git a/run_with_finetuning.py b/run_with_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..65611793a361717add42155ed914bfc7c7e71321 --- /dev/null +++ b/run_with_finetuning.py @@ -0,0 +1,45 @@ +__author__ = "Lukas Leufen" +__date__ = '2020-06-29' + +import argparse +from mlair.workflows import DefaultWorkflow +from mlair.helpers import remove_items +from mlair.configuration.defaults import DEFAULT_PLOT_LIST +from mlair.model_modules.model_class import IntelliO3_ts_architecture, IntelliO3_ts_architecture_freeze +import os + + +def load_stations(): + import json + try: + filename = 'supplement/station_list_north_german_plain_rural.json' + with open(filename, 'r') as jfile: + stations = json.load(jfile) + except FileNotFoundError: + stations = None + return stations + + +def main(parser_args): + plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles") + workflow = DefaultWorkflow( # stations=load_stations(), + stations=["DEBW087", "DEBW013", "DEBW107", "DEBW076"], + #stations=["DEBW013", "DEBW087"], + epochs=1, external_weights="/home/vincentgramlich/mlair/data/weights/testrun_network_daily_model-best.h5", + train_model=True, create_new_model=True, network="UBA", + model=IntelliO3_ts_architecture_freeze, + window_lead_time=1, + #oversampling_method="bin_oversampling", oversampling_bins=10, oversampling_rates_cap=100, window_lead_time=2, + evaluate_bootstraps=False, plot_list=["PlotContingency"], + competitors=["withoutfinetuning"], + competitor_path=os.path.join(os.getcwd(), "data", "competitors", "o3"), + **parser_args.__dict__, start_script=__file__) + workflow.run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun", + help="set experiment date as string") + args = parser.parse_args() + main(args) diff --git a/run_without_finetuning.py b/run_without_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..61644a33195f8c5124b61f164c4e5831081ce0fb --- /dev/null +++ b/run_without_finetuning.py @@ -0,0 +1,45 @@ +__author__ = "Lukas Leufen" +__date__ = '2020-06-29' + +import argparse +from mlair.workflows import DefaultWorkflow +from mlair.helpers import remove_items +from mlair.configuration.defaults import DEFAULT_PLOT_LIST +from mlair.model_modules.model_class import IntelliO3_ts_architecture, IntelliO3_ts_architecture_freeze +import os + + +def load_stations(): + import json + try: + filename = 'supplement/station_list_north_german_plain_rural.json' + with open(filename, 'r') as jfile: + stations = json.load(jfile) + except FileNotFoundError: + stations = None + return stations + + +def main(parser_args): + plots = remove_items(DEFAULT_PLOT_LIST, "PlotConditionalQuantiles") + workflow = DefaultWorkflow( # stations=load_stations(), + stations=["DEBW087","DEBW013", "DEBW107", "DEBW076"], + #stations=["DEBW013", "DEBW087"], + epochs=5, #external_weights="/home/vincentgramlich/mlair/data/weights/test_weight.h5", + train_model=True, create_new_model=True, network="UBA", + model=IntelliO3_ts_architecture, + window_lead_time=1, + #oversampling_method="bin_oversampling", oversampling_bins=10, oversampling_rates_cap=100, window_lead_time=2, + evaluate_bootstraps=False, plot_list=plots, + #competitors=["testcompetitor", "testcompetitor2"], + competitor_path=os.path.join(os.getcwd(), "data", "competitors"), + **parser_args.__dict__, start_script=__file__) + workflow.run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default="testrun", + help="set experiment date as string") + args = parser.parse_args() + main(args) \ No newline at end of file