diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 7eb1cd7ac93ad7ea438a738bcf2ab5c1dd6397a2..ff2cffcdf01fd9e917bf1120984c6b65e1f5a13d 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -28,6 +28,7 @@ class Training(RunEnvironment): self.lr_sc = self.data_store.get("lr_decay", "general.model") self.hist = self.data_store.get("hist", "general.model") self.experiment_name = self.data_store.get("experiment_name", "general") + self._trainable = self.data_store.get("trainable", "general") self._run() def _run(self) -> None: @@ -44,8 +45,11 @@ class Training(RunEnvironment): """ self.set_generators() self.make_predict_function() - self.train() - self.save_model() + if self._trainable: + self.train() + self.save_model() + else: + logging.info("No training has started, because trainable parameter was false.") def make_predict_function(self) -> None: """