From adcdd8c9eb4475b32bfabb88d428a67771046807 Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Fri, 7 Aug 2020 08:17:11 +0200
Subject: [PATCH] address/correct the number of samples per epoch number issue

---
 .../scripts/generate_transfer_learning_finetune.py          | 1 +
 .../video_prediction/datasets/era5_dataset_v2.py            | 6 +++---
 .../video_prediction/models/vanilla_convLSTM_model.py       | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index 13b93889..3df6f7e2 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -357,6 +357,7 @@ def main():
     sess.graph.as_default()
     sess.run(tf.global_variables_initializer())
     sess.run(tf.local_variables_initializer())
+    model.restore(sess, args.checkpoint)
     
     #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data
     sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data()
diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
index 5baad396..a3c9fc36 100644
--- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
@@ -366,13 +366,13 @@ def main():
                # "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
                # "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
                # "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
-                "2017":[1,2,3,4,5,6,7,8,9,10]
+                "2017_test":[1,2,3,4,5,6,7,8,9,10]
                  },
             "val":
-                {"2017":[11]
+                {"2017_test":[11]
                  },
             "test":
-                {"2017":[12]
+                {"2017_test":[12]
                  }
             }
     
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index c7f3db7c..7560a225 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -41,7 +41,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
             max_steps: number of training steps.
-            context_frames: the number of ground-truth frames to pass in at
+            context_frames: the number of ground-truth frames to pass :qin at
                 start. Must be specified during instantiation.
             sequence_length: the number of frames in the video sequence,
                 including the context frames, so this model predicts
-- 
GitLab