Commit 0f645b8d authored by lukas leufen

Merge branch 'lukas_issue071_feat_optional-lr-decay' into 'develop'

Resolve "lr_decay should be optional"

See merge request toar/machinelearningtools!73
parents 01d0fca0 d4849364
3 merge requests: !90 WIP: new release update, !89 Resolve "release branch / CI on gpu", !73 Resolve "lr_decay should be optional"
Pipeline #32101 passed
@@ -70,11 +70,12 @@ class ModelSetup(RunEnvironment):
         Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the
         advanced model checkpoint is added.
         """
-        lr = self.data_store.get("lr_decay", scope="general.model")
+        lr = self.data_store.get_default("lr_decay", scope="general.model", default=None)
         hist = HistoryAdvanced()
         self.data_store.set("hist", hist, scope="general.model")
         callbacks = CallbackHandler()
-        callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
+        if lr:
+            callbacks.add_callback(lr, self.callbacks_name % "lr", "lr")
         callbacks.add_callback(hist, self.callbacks_name % "hist", "hist")
         callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss',
                                           save_best_only=True, mode='auto')
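
The switch from data_store.get to data_store.get_default(..., default=None) in the hunk above assumes the data store returns the supplied default instead of raising when no "lr_decay" entry exists. A minimal sketch of that assumed behaviour (SimpleDataStore is a hypothetical stand-in, not the project's data store class):

    class SimpleDataStore:
        """Illustrative stand-in for the project's data store."""

        def __init__(self):
            self._store = {}  # maps (name, scope) -> stored object

        def set(self, name, obj, scope):
            self._store[(name, scope)] = obj

        def get(self, name, scope):
            # strict lookup: raises if the entry is missing
            return self._store[(name, scope)]

        def get_default(self, name, scope, default=None):
            # tolerant lookup: fall back to `default` instead of raising,
            # which is what _set_callbacks relies on when lr_decay is unset
            try:
                return self.get(name, scope)
            except KeyError:
                return default

With this fallback, lr is simply None when no decay was configured, and the "if lr" guard in the hunk above skips registering the learning-rate callback.
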
@@ -111,7 +111,10 @@ class Training(RunEnvironment):
                                            callbacks=self.callbacks.get_callbacks(as_dict=False),
                                            initial_epoch=initial_epoch)
             history = hist
-        lr = self.callbacks.get_callback_by_name("lr")
+        try:
+            lr = self.callbacks.get_callback_by_name("lr")
+        except IndexError:
+            lr = None
         self.save_callbacks_as_json(history, lr)
         self.load_best_model(checkpoint.filepath)
         self.create_monitoring_plots(history, lr)
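
The try/except in the hunk above relies on CallbackHandler.get_callback_by_name raising IndexError when no callback with that name was registered. A hedged sketch of a registry with that behaviour (MiniCallbackHandler and its internal structure are assumptions for illustration only):

    class MiniCallbackHandler:
        """Illustrative callback registry, not the project's CallbackHandler."""

        def __init__(self):
            self._callbacks = []  # each entry: {"name": ..., "callback": ..., "path": ...}

        def add_callback(self, callback, callback_path, name):
            self._callbacks.append({"name": name, "callback": callback, "path": callback_path})

        def get_callbacks(self, as_dict=True):
            return self._callbacks if as_dict else [c["callback"] for c in self._callbacks]

        def get_callback_by_name(self, name):
            # indexing an empty match list raises IndexError, which
            # Training.train() now catches to treat lr decay as optional
            return [c["callback"] for c in self._callbacks if c["name"] == name][0]

Catching IndexError rather than checking beforehand keeps the call site unchanged for experiments that do configure a learning-rate decay.
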
@@ -148,8 +151,9 @@ class Training(RunEnvironment):
         path = self.data_store.get("experiment_path", "general")
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
-        with open(os.path.join(path, "history_lr.json"), "w") as f:
-            json.dump(lr_sc.lr, f)
+        if lr_sc:
+            with open(os.path.join(path, "history_lr.json"), "w") as f:
+                json.dump(lr_sc.lr, f)
 
     def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None:
         """
@@ -174,4 +178,5 @@ class Training(RunEnvironment):
         PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)
 
         # plot learning rate
-        PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
+        if lr_sc:
+            PlotModelLearningRate(filename=os.path.join(path, f"{name}_history_learning_rate.pdf"), lr_sc=lr_sc)
@@ -4,6 +4,7 @@ import pytest
 
 from src.data_handling.data_generator import DataGenerator
 from src.datastore import EmptyScope
+from src.model_modules.keras_extensions import CallbackHandler
 from src.model_modules.model_class import AbstractModelClass
 from src.run_modules.model_setup import ModelSetup
 from src.run_modules.run_environment import RunEnvironment
@@ -61,6 +62,18 @@ class TestModelSetup:
         setup.checkpoint_name = "TestName"
         setup._set_callbacks()
         assert "general.modeltest" in setup.data_store.search_name("callbacks")
+        callbacks = setup.data_store.get("callbacks", "general.modeltest")
+        assert len(callbacks.get_callbacks()) == 3
+
+    def test_set_callbacks_no_lr_decay(self, setup):
+        setup.data_store.set("lr_decay", None, "general.model")
+        assert "general.modeltest" not in setup.data_store.search_name("callbacks")
+        setup.checkpoint_name = "TestName"
+        setup._set_callbacks()
+        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.modeltest")
+        assert len(callbacks.get_callbacks()) == 2
+        with pytest.raises(IndexError):
+            callbacks.get_callback_by_name("lr_decay")
+
     def test_get_model_settings(self, setup_with_model):
         with pytest.raises(EmptyScope):
@@ -73,7 +86,7 @@ class TestModelSetup:
         setup_with_gen.build_model()
         assert isinstance(setup_with_gen.model, AbstractModelClass)
         expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
-                    "optimizer", "lr_decay", "epochs", "batch_size", "activation"}
+                    "optimizer", "epochs", "batch_size", "activation"}
         assert expected <= self.current_scope_as_set(setup_with_gen)
 
     def test_set_channels(self, setup_with_gen_tiny):
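
The last hunk drops "lr_decay" from the settings a built model is expected to publish to the data store, so a model class may now omit it entirely. A rough, hypothetical illustration (the attribute names other than lr_decay come from the expected-settings list in the test above; the class skeleton itself is invented):

    class NoDecayModelSettings:
        """Hypothetical settings holder showing the now-optional attribute."""

        def __init__(self):
            self.initial_lr = 1e-2
            self.optimizer = "adam"
            self.epochs = 10
            self.batch_size = 512
            self.activation = "relu"
            self.dropout_rate = 0.1
            # no self.lr_decay: ModelSetup now falls back to None and skips
            # the learning-rate callback instead of failing on a missing key
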