From adcdd8c9eb4475b32bfabb88d428a67771046807 Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Fri, 7 Aug 2020 08:17:11 +0200
Subject: [PATCH] address/correct the number of samples per epoch issue

---
 .../scripts/generate_transfer_learning_finetune.py    | 1 +
 .../video_prediction/datasets/era5_dataset_v2.py      | 6 +++---
 .../video_prediction/models/vanilla_convLSTM_model.py | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index 13b93889..3df6f7e2 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -357,6 +357,7 @@ def main():
     sess.graph.as_default()
     sess.run(tf.global_variables_initializer())
     sess.run(tf.local_variables_initializer())
+    model.restore(sess, args.checkpoint)
     #model.restore(sess, args.checkpoint) #Bing: Todo: 20200728 Let's only focus on true and persistent data
     sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data()

diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
index 5baad396..a3c9fc36 100644
--- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
@@ -366,13 +366,13 @@ def main():
     #        "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
     #        "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
     #        "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
-             "2017":[1,2,3,4,5,6,7,8,9,10]
+             "2017_test":[1,2,3,4,5,6,7,8,9,10]
              },
     "val":
-            {"2017":[11]
+            {"2017_test":[11]
             },
     "test":
-            {"2017":[12]
+            {"2017_test":[12]
             }
     }

diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index c7f3db7c..7560a225 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -41,7 +41,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         lr: learning rate. if decay steps is non-zero, this is the
             learning rate for steps <= decay_step.
         max_steps: number of training steps.
-        context_frames: the number of ground-truth frames to pass :qin at
+        context_frames: the number of ground-truth frames to pass in at
             start. Must be specified during instantiation.
         sequence_length: the number of frames in the video sequence,
             including the context frames, so this model predicts
--
GitLab
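
[Editor's note, not part of the commit] The functional fix in the first
hunk is the re-enabled model.restore(sess, args.checkpoint). Without it,
the generation script would run on the freshly initialized variables
left by tf.global_variables_initializer(), i.e. on an untrained network.
Below is a minimal sketch of what such a restore step does, assuming the
TF1.x API used throughout video_prediction_savp; restore_checkpoint and
its arguments are illustrative, not the repo's exact helper:

    import tensorflow as tf  # TF1.x, as used in this repo

    def restore_checkpoint(sess, checkpoint_dir):
        """Load trained weights into the session before generation.

        Skipping this step leaves every variable at its random initial
        value, so the "predictions" would come from an untrained net.
        """
        ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt_path is None:
            raise ValueError("no checkpoint found in %s" % checkpoint_dir)
        saver = tf.train.Saver()        # covers all saveable variables
        saver.restore(sess, ckpt_path)  # overwrite the initialized values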
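[Editor's note, not part of the commit] The subject line concerns the
number of samples per epoch: renaming the split keys to "2017_test" pins
train/val/test to specific months, and the frame count of those months
determines how many sequence windows one epoch contains. As a
back-of-the-envelope sketch (illustrative arithmetic only, not the exact
counting done in era5_dataset_v2.py): a run of n_frames consecutive
frames yields n_frames - sequence_length + 1 overlapping windows, each
holding context_frames conditioning frames plus
sequence_length - context_frames frames to predict, per the docstring
fixed in the last hunk.

    def samples_per_epoch(n_frames, sequence_length, stride=1):
        """Windows in one pass over n_frames consecutive frames.

        stride=1 counts every overlapping window; stride=sequence_length
        counts disjoint windows. Illustrative only, not the repo's code.
        """
        if n_frames < sequence_length:
            return 0
        return (n_frames - sequence_length) // stride + 1

    # Example: a 31-day month of hourly ERA5 frames with 20-frame windows
    # (sequence_length=20 and context_frames=10 are assumed values).
    print(samples_per_epoch(31 * 24, 20))  # 725 overlapping windows
    # Each 20-frame window: 10 context frames + 10 frames to predict.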