diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 069e8ec9ea1812d8254454b33f670209d65024d9..c6226ef7d55e53697626fa4118c09c06ed710060 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -259,7 +259,7 @@ class TrainModel(object): """ Restore the train and validation losses in the pickle file if checkpoint is given """ - if self.start_step == 0: + if os.path.exists(os.path.join(self.output_dir,"checkpoint")): train_losses = [] val_losses = [] else: @@ -278,7 +278,8 @@ class TrainModel(object): print("parameter_count =", sess.run(self.parameter_count)) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) - self.restore(sess, self.checkpoint) + if os.path.exists(os.path.join(self.output_dir,"checkpoint")): + self.restore(sess, self.checkpoint) #sess.graph.finalize() self.start_step = sess.run(self.global_step) print("start_step", self.start_step)