diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index 13b93889875779942e5171e5e1d98eebc84fd9f3..3df6f7e2843eb732df4d0be70f410853a0ac2a78 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -357,6 +357,7 @@ def main(): sess.graph.as_default() sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) + model.restore(sess, args.checkpoint) #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data() diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py index 70797c34344b82ba980031ecc08eeb999861f396..a3c9fc3666eb21ef02f9e5f64c8c95a29034d619 100644 --- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py +++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py @@ -62,7 +62,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset): sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) - def filter(self, serialized_example): return tf.convert_to_tensor(True) @@ -326,7 +325,7 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file, temp_input_file #sequence_lengths_file.close() return -def write_sequence_file(output_dir,seq_length): +def write_sequence_file(output_dir,seq_length,sequences_per_file): partition_names = ["train","val","test"] for partition_name in partition_names: @@ -334,7 +333,7 @@ def write_sequence_file(output_dir,seq_length): tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords")) print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter)) sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w') - for i in range(tfCounter): + for i in range(tfCounter*sequences_per_file): sequence_lengths_file.write("%d\n" % seq_length) sequence_lengths_file.close() @@ -350,6 +349,7 @@ def main(): parser.add_argument("-height",type=int,default=64) parser.add_argument("-width",type = int,default=64) parser.add_argument("-seq_length",type=int,default=20) + parser.add_argument("-sequences_per_file",type=int,default=2) args = parser.parse_args() current_path = os.getcwd() #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits" @@ -366,13 +366,13 @@ def main(): # "2012":[1,2,3,4,5,6,7,8,9,10,11,12], # "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12], # "2015":[1,2,3,4,5,6,7,8,9,10,11,12], - "2017":[1,2,3,4,5,6,7,8,9,10] + "2017_test":[1,2,3,4,5,6,7,8,9,10] }, "val": - {"2017":[11] + {"2017_test":[11] }, "test": - {"2017":[12] + {"2017_test":[12] } } @@ -424,7 +424,8 @@ def main(): message_counter = message_counter + 1 print("Message in from slaver",message_in) - write_sequence_file(args.output_dir,args.seq_length) + write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file) + #write_sequence_file else: message_in = comm.recv() @@ -449,7 +450,7 @@ def main(): read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir, \ input_file=input_file,temp_input_file=temp_file,vars_in=args.variables, \ partition_name=partition_name,seq_length=args.seq_length, \ - height=args.height,width=args.width,sequences_per_file=20) + height=args.height,width=args.width,sequences_per_file=args.sequences_per_file) print("Year {} finished",year) message_out = ("Node:",str(my_rank),"finished","","\r\n") diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py index c7f3db7ce4fce732312eba0d9f17362faa2e64b5..7560a225e7651728e2ca8d2107d7f32458106c86 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py @@ -41,7 +41,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): lr: learning rate. if decay steps is non-zero, this is the learning rate for steps <= decay_step. max_steps: number of training steps. - context_frames: the number of ground-truth frames to pass in at + context_frames: the number of ground-truth frames to pass :qin at start. Must be specified during instantiation. sequence_length: the number of frames in the video sequence, including the context frames, so this model predicts