diff --git a/video_prediction_tools/HPC_scripts/train_model_era5_template.sh b/video_prediction_tools/HPC_scripts/train_model_era5_template.sh
index fff363affdd39ac4c10542961ab084406ab6ad62..34a4c9dcbc23d4cc7ab1fd363e20b84e4d1ee661 100644
--- a/video_prediction_tools/HPC_scripts/train_model_era5_template.sh
+++ b/video_prediction_tools/HPC_scripts/train_model_era5_template.sh
@@ -50,8 +50,5 @@ dataset=era5
 
 # run training
 srun python ../main_scripts/main_train_models.py --input_dir  ${source_dir} --datasplit_dict ${datasplit_dict} \
- --dataset ${dataset}  --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}
-
-
-
+ --dataset ${dataset}  --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir} --checkpoint_dir  ${destination_dir}
  
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index c6ac709d7d5d82487e60ab915bf7b41cd11ffabc..f6fd6e78db9c49f433d679c03c97463a654192d7 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -259,7 +259,7 @@ class TrainModel(object):
         """
         Restore the train and validation losses in the pickle file if checkpoint is given 
         """
-        if self.start_step == 0:
+        if not os.path.exists(os.path.join(self.output_dir,"checkpoint")):
             train_losses = []
             val_losses = []
         else:
@@ -278,7 +278,8 @@ class TrainModel(object):
             print("parameter_count =", sess.run(self.parameter_count))
             sess.run(tf.global_variables_initializer())
             sess.run(tf.local_variables_initializer())
-            self.restore(sess, self.checkpoint)
+            if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
+                self.restore(sess, self.checkpoint)
             #sess.graph.finalize()
             self.start_step = sess.run(self.global_step)
             print("start_step", self.start_step)