Commit 40bc20bd authored by Lukas Leufen

can run training and save callbacks

parent 389504fa
2 merge requests: !24 include recent development, !20 not distributed training
Pipeline #27179 passed
@@ -9,7 +9,8 @@
 from src.modules.experiment_setup import ExperimentSetup
 from src.modules.run_environment import RunEnvironment
 from src.modules.pre_processing import PreProcessing
 from src.modules.model_setup import ModelSetup
-from src.modules.modules import Training, PostProcessing
+from src.modules.training import Training
+from src.modules.modules import PostProcessing

 def main(parser_args):
......@@ -32,8 +33,8 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None,
help="set experiment date as string")
args = parser.parse_args()
args = parser.parse_args(["--experiment_date", "testrun"])
main(args)
@@ -178,11 +178,11 @@ def set_experiment_name(experiment_date=None, experiment_path=None):
     if experiment_date is None:
         experiment_name = "TestExperiment"
     else:
-        experiment_name = f"{experiment_date}_network/"
+        experiment_name = f"{experiment_date}_network"
     if experiment_path is None:
         experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", experiment_name))
     else:
-        experiment_path = os.path.abspath(experiment_path)
+        experiment_path = os.path.join(os.path.abspath(experiment_path), experiment_name)
     return experiment_name, experiment_path
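For orientation, the change above means a user-supplied experiment_path no longer discards the experiment name, and the name itself no longer carries a trailing slash. A minimal sketch (hypothetical helper and values, not part of the commit):

import os

def build_experiment_location(experiment_date, experiment_path):
    # mirrors only the two changed else-branches of set_experiment_name
    experiment_name = f"{experiment_date}_network"  # trailing slash removed
    experiment_path = os.path.join(os.path.abspath(experiment_path), experiment_name)
    return experiment_name, experiment_path

print(build_experiment_location("2019-11-14", "./test2"))
# ('2019-11-14_network', '<cwd>/test2/2019-11-14_network'); the absolute part depends on the working directory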
@@ -24,7 +24,9 @@ class ModelSetup(RunEnvironment):
         # create run framework
         super().__init__()
         self.model = None
-        self.model_name = self.data_store.get("experiment_name", "general") + "model-best.h5"
+        path = self.data_store.get("experiment_path", "general")
+        exp_name = self.data_store.get("experiment_name", "general")
+        self.model_name = os.path.join(path, f"{exp_name}_model-best.h5")
         self.scope = "general.model"
         self._run()
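With experiment_path joined in, the checkpoint file now lands inside the experiment directory instead of being written as a bare file name. A sketch with made-up values (in ModelSetup both values come from the data store):

import os

path = "/home/user/experiments/2019-11-14_network"   # made-up experiment_path
exp_name = "2019-11-14_network"                       # made-up experiment_name
model_name = os.path.join(path, f"{exp_name}_model-best.h5")
print(model_name)
# /home/user/experiments/2019-11-14_network/2019-11-14_network_model-best.h5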
@@ -74,7 +76,7 @@ class ModelSetup(RunEnvironment):
     def plot_model(self):  # pragma: no cover
         with tf.device("/cpu:0"):
             path = self.data_store.get("experiment_path", "general")
-            name = self.data_store.get("experiment_name", "general") + "model.pdf"
+            name = self.data_store.get("experiment_name", "general") + "_model.pdf"
             file_name = os.path.join(path, name)
             keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
@@ -100,7 +102,7 @@ class ModelSetup(RunEnvironment):
         self.data_store.put("lr_decay", LearningRateDecay(base_lr=initial_lr, drop=.94, epochs_drop=10), self.scope)

         # learning settings
-        self.data_store.put("epochs", 2, self.scope)
+        self.data_store.put("epochs", 10, self.scope)
         self.data_store.put("batch_size", int(256), self.scope)

         # activation
@@ -3,6 +3,8 @@ __author__ = "Lukas Leufen, Felix Kleinert"
 __date__ = '2019-12-05'

 import logging
+import os
+import json

 from src.modules.run_environment import RunEnvironment
 from src.data_handling.data_distributor import Distributor
@@ -13,13 +15,20 @@ class Training(RunEnvironment):

     def __init__(self):
         super().__init__()
         self.model = self.data_store.get("model", "general.model")
-        self.train_generator = None
-        self.val_generator = None
-        self.test_generator = None
+        self.train_set = None
+        self.val_set = None
+        self.test_set = None
         self.batch_size = self.data_store.get("batch_size", "general.model")
         self.epochs = self.data_store.get("epochs", "general.model")
         self.checkpoint = self.data_store.get("checkpoint", "general.model")
-        self.lr_sc = self.data_store.get("epochs", "general.model")
+        self.lr_sc = self.data_store.get("lr_decay", "general.model")
+        self.experiment_name = self.data_store.get("experiment_name", "general")
+        self._run()
+
+    def _run(self):
+        self.set_generators()
+        self.make_predict_function()
+        self.train()

     def make_predict_function(self):
         # create the predict function before distributing. This is necessary, because tf will compile the predict
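The new _run method establishes the sequence of the training stage: build the distributors, compile the predict function, then fit. A runnable sketch of that call order with stubbed bodies (names follow the diff, the printed behaviour is illustrative only):

class TrainingSketch:
    """Stubbed stand-in for Training, showing only the _run call order."""
    def set_generators(self):
        print("wrap train/val/test generators in Distributor objects")
    def make_predict_function(self):
        print("build the keras predict function before any distribution happens")
    def train(self):
        print("fit_generator with checkpoint and lr_decay callbacks, then save_callbacks")
    def _run(self):
        self.set_generators()
        self.make_predict_function()
        self.train()

TrainingSketch()._run()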
@@ -30,19 +39,29 @@ class Training(RunEnvironment):

     def _set_gen(self, mode):
         gen = self.data_store.get("generator", f"general.{mode}")
-        setattr(self, f"{mode}_generator", Distributor(gen, self.model, self.batch_size))
+        setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size))

     def set_generators(self):
-        map(lambda mode: self._set_gen(mode), ["train", "val", "test"])
+        for mode in ["train", "val", "test"]:
+            self._set_gen(mode)

     def train(self):
-        logging.info(f"Train with {len(self.train_generator)} mini batches.")
-        history = self.model.fit_generator(generator=self.train_generator.distribute_on_batches(),
-                                           steps_per_epoch=len(self.train_generator),
+        logging.info(f"Train with {len(self.train_set)} mini batches.")
+        history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
+                                           steps_per_epoch=len(self.train_set),
                                            epochs=self.epochs,
                                            verbose=2,
-                                           validation_data=self.val_generator.distribute_on_batches(),
-                                           validation_steps=len(self.val_generator),
+                                           validation_data=self.val_set.distribute_on_batches(),
+                                           validation_steps=len(self.val_set),
                                            callbacks=[self.checkpoint, self.lr_sc])
+        self.save_callbacks(history)
+
+    def save_callbacks(self, history):
+        path = self.data_store.get("experiment_path", "general")
+        with open(os.path.join(path, "history.json"), "w") as f:
+            json.dump(history.history, f)
+        with open(os.path.join(path, "history_lr.json"), "w") as f:
+            json.dump(self.lr_sc.lr, f)
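Two things are worth noting in this hunk: set_generators switches from map() to an explicit for loop because map is lazy in Python 3, so the previous call never actually executed _set_gen; and the keras History plus the learning rates recorded by the LearningRateDecay callback are now dumped as plain JSON. That makes them easy to reload later without keras, for example (a sketch; file names as in the diff, the path and the history keys are assumptions):

import json
import os

experiment_path = "/home/user/experiments/2019-11-14_network"  # made-up path

with open(os.path.join(experiment_path, "history.json")) as f:
    history = json.load(f)      # e.g. {"loss": [...], "val_loss": [...]}
with open(os.path.join(experiment_path, "history_lr.json")) as f:
    lr_values = json.load(f)    # learning rate values collected per epoch

print(min(history["val_loss"]), len(lr_values))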
@@ -183,9 +183,9 @@ class TestSetExperimentName:
         assert exp_name == "TestExperiment"
         assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "TestExperiment"))
         exp_name, exp_path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2")
-        assert exp_name == "2019-11-14_network/"
+        assert exp_name == "2019-11-14_network"
         assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test2"))

     def test_set_experiment_from_sys(self):
         exp_name, _ = set_experiment_name(experiment_date="2019-11-14")
-        assert exp_name == "2019-11-14_network/"
+        assert exp_name == "2019-11-14_network"
@@ -39,8 +39,8 @@ class TestPreProcessing:
         with PreProcessing():
             assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
             assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
-            assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). '
-                                                                        r'Found 5/5 valid stations.'))
+            assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
+                                                                        r'station\(s\). Found 5/5 valid stations.'))
         RunEnvironment().__del__()

     def test_run(self, obj_with_exp_setup):
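The expected log message changes because elapsed time is now reported as hh:mm:ss instead of fractional seconds. One way to produce a matching string (a sketch; the real formatting lives in the run environment's time tracking and may differ):

import datetime as dt

start = dt.datetime.now()
# ... check the stations ...
elapsed = dt.datetime.now() - start
duration = str(elapsed).split(".")[0]   # drop microseconds, e.g. "0:00:03"
print(f"run for {duration} (hh:mm:ss) to check 5 station(s). Found 5/5 valid stations.")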
@@ -88,8 +88,8 @@ class TestPreProcessing:
         assert len(valid_stations) < len(stations)
         assert valid_stations == stations[:-1]
         assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found '
-                                                                    r'5/6 valid stations.'))
+        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
+                                                                    r'station\(s\). Found 5/6 valid stations.'))

     def test_split_set_indices(self, obj_super_init):
         dummy_list = list(range(0, 15))