From ad811eef236e827f128bcd7b6d4f85f6eaf72dc9 Mon Sep 17 00:00:00 2001 From: "v.gramlich1" <v.gramlichfz-juelich.de> Date: Thu, 2 Sep 2021 08:47:17 +0200 Subject: [PATCH] Add model classes IntelliO3_ts_architecture_finetune_all_dense, IntelliO3_ts_architecture_finetune_outputs and IntelliO3_ts_architecture_finetune_main_output --- mlair/model_modules/model_class.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/mlair/model_modules/model_class.py b/mlair/model_modules/model_class.py index 83434268..29fe727e 100644 --- a/mlair/model_modules/model_class.py +++ b/mlair/model_modules/model_class.py @@ -465,7 +465,7 @@ class IntelliO3_ts_architecture(AbstractModelClass): "loss_weights": [.01, .99] } -class IntelliO3_ts_architecture_freeze(IntelliO3_ts_architecture): +class IntelliO3_ts_architecture_finetune_all_dense(IntelliO3_ts_architecture): def __init__(self, input_shape: list, output_shape: list): super().__init__(input_shape, output_shape) @@ -478,10 +478,31 @@ class IntelliO3_ts_architecture_freeze(IntelliO3_ts_architecture): for layer in self.model.layers: if not isinstance(layer, keras.layers.core.Dense): layer.trainable = False - ''' + +class IntelliO3_ts_architecture_finetune_outputs(IntelliO3_ts_architecture): + def __init__(self, input_shape: list, output_shape: list): + super().__init__(input_shape, output_shape) + + self.freeze_layers() + self.initial_lr = 1e-5 + self.apply_to_model() + # self.lr_decay = None def freeze_layers(self): for layer in self.model.layers: if layer.name not in ["minor_1_out_Dense", "Main_out_Dense"]: layer.trainable = False - ''' \ No newline at end of file + +class IntelliO3_ts_architecture_finetune_main_output(IntelliO3_ts_architecture): + def __init__(self, input_shape: list, output_shape: list): + super().__init__(input_shape, output_shape) + + self.freeze_layers() + self.initial_lr = 1e-5 + self.apply_to_model() + # self.lr_decay = None + + def freeze_layers(self): + for layer in self.model.layers: + if layer.name not in ["Main_out_Dense"]: + layer.trainable = False \ No newline at end of file -- GitLab