diff --git a/src/modules/training.py b/src/modules/training.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef3138a83f117f558e11022e2d5053a21666364 --- /dev/null +++ b/src/modules/training.py @@ -0,0 +1,19 @@ +__author__ = "Lukas Leufen" +__date__ = '2019-12-05' + + +from src.modules.run_environment import RunEnvironment + + +class Training(RunEnvironment): + + def __init__(self): + super().__init__() + self.model = self.data_store.get("model", "general.model") + + def make_predict_function(self): + # create the predict function before distributing. This is necessary, because tf will compile the predict + # function just in the moment it is used the first time. This can cause problems, if the model is distributed + # on different workers. To prevent this, the function is pre-compiled. See discussion @ + # https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252 + self.model._make_predict_function() \ No newline at end of file