diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 707085b6a7a82c9cc16be5118758e1d0b52abd5c..9e58de96a31913eb19678e151fac5c46d6e80409 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -247,6 +247,10 @@ class TrainModel(object): self.steps_per_epoch = int(self.num_examples/self.batch_size) self.total_steps = self.steps_per_epoch * self.max_epochs self.diag_intv_step = int(self.diag_intv_frac*self.total_steps) + if self.diag_intv_step == 0: + self.diag_intv_step = 1 + else: + pass print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}" .format(method, self.batch_size, self.max_epochs, self.num_examples, self.steps_per_epoch, self.total_steps))