From 7424bcc4a426c25fa4d78329a407b87078efe8eb Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Tue, 18 May 2021 15:02:08 +0200
Subject: [PATCH] check if checkpoint exists

---
 video_prediction_tools/main_scripts/main_train_models.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 069e8ec9..c6226ef7 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -259,7 +259,7 @@ class TrainModel(object):
         """
         Restore the train and validation losses in the pickle file if checkpoint is given 
         """
-        if self.start_step == 0:
+        if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
             train_losses = []
             val_losses = []
         else:
@@ -278,7 +278,8 @@ class TrainModel(object):
             print("parameter_count =", sess.run(self.parameter_count))
             sess.run(tf.global_variables_initializer())
             sess.run(tf.local_variables_initializer())
-            self.restore(sess, self.checkpoint)
+            if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
+                self.restore(sess, self.checkpoint)
             #sess.graph.finalize()
             self.start_step = sess.run(self.global_step)
             print("start_step", self.start_step)
-- 
GitLab