diff --git a/Zam347_scripts/DataPreprocess_to_tf.sh b/Zam347_scripts/DataPreprocess_to_tf.sh index 320307c05e5d4789228e123d50abd7cdb5d39c50..608f95348b25c2169ae2963821639de8947c322b 100755 --- a/Zam347_scripts/DataPreprocess_to_tf.sh +++ b/Zam347_scripts/DataPreprocess_to_tf.sh @@ -1,4 +1,4 @@ #!/bin/bash -x -python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 +python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh index 42e28355e6ebe5a99697dc3c57b8b1ad9a859260..0f1b2c8930befb29399d2d943d558ca8d8e412d4 100755 --- a/Zam347_scripts/train_era5.sh +++ b/Zam347_scripts/train_era5.sh @@ -2,5 +2,5 @@ -python ../scripts/train_v2.py --input_dir /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir /home/${USER}/models/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/ours_savp +python ../scripts/train_v2.py --input_dir /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5 --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir /home/${USER}/models/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/ours_savp #srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py index 9ff970517c18155daea43c4646d8c750a3c5a509..606f32f0bb47a66e1190be5b9585e6ef9b5cf752 100644 --- a/video_prediction/datasets/era5_dataset_v2.py +++ b/video_prediction/datasets/era5_dataset_v2.py @@ -27,8 +27,9 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset): example = next(tf.python_io.tf_record_iterator(self.filenames[0])) dict_message = MessageToDict(tf.train.Example.FromString(example)) feature = dict_message['features']['feature'] - image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels']) - self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape + self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) + self.image_shape = self.video_shape[1:] + self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape def get_default_hparams_dict(self): default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict() @@ -64,17 +65,23 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset): def parser(serialized_example): seqs = OrderedDict() keys_to_features = { - # 'width': tf.FixedLenFeature([], tf.int64), - # 'height': tf.FixedLenFeature([], tf.int64), + 'width': tf.FixedLenFeature([], tf.int64), + 'height': tf.FixedLenFeature([], tf.int64), 'sequence_length': tf.FixedLenFeature([], tf.int64), - # 'channels': tf.FixedLenFeature([],tf.int64), + 'channels': tf.FixedLenFeature([],tf.int64), # 'images/encoded': tf.FixedLenFeature([], tf.string) 'images/encoded': tf.VarLenFeature(tf.float32) } + # for i in range(20): # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) parsed_features = tf.parse_single_example(serialized_example, keys_to_features) + print ("Parse features", parsed_features) seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) + # width = tf.sparse_tensor_to_dense(parsed_features["width"]) + # height = tf.sparse_tensor_to_dense(parsed_features["height"]) + # channels = tf.sparse_tensor_to_dense(parsed_features["channels"]) + # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"]) images = [] # for i in range(20): # images.append(parsed_features["images/encoded"].values[i]) @@ -85,8 +92,8 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset): # images = tf.decode_raw(parsed_features["images/encoded"],tf.int32) # images = seq - images = tf.reshape(seq, [20, 128, 160, 3], name = "reshape_new") - print("IMAGES", images) + print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) + images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") seqs["images"] = images return seqs filenames = self.filenames @@ -154,7 +161,7 @@ def save_tf_record(output_fname, sequences): example = tf.train.Example(features=features) writer.write(example.SerializeToString()) -def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,N_seq,sequences_per_file=128,**kwargs):#Bing: original 128 +def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128 # ML 2020/04/08: # Include vars_in for more flexible data handling (normalization and reshaping) # and optional keyword argument for kind of normalization @@ -189,15 +196,15 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in, sequence_iter = 0 sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl")) - X_possible_starts = [i for i in range(len(X_train) - N_seq)] + X_possible_starts = [i for i in range(len(X_train) - seq_length)] for X_start in X_possible_starts: print("Interation", sequence_iter) - X_end = X_start + N_seq + X_end = X_start + seq_length #seq = X_train[X_start:X_end, :, :,:] seq = X_train[X_start:X_end,:,:] #print("*****len of seq ***.{}".format(len(seq))) #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3))) - seq = list(np.array(seq).reshape((len(seq), 128, 160,nvars))) + seq = list(np.array(seq).reshape((seq_length, height, width, nvars))) if not sequences: last_start_sequence_iter = sequence_iter print("reading sequences starting at sequence %d" % sequence_iter) @@ -230,8 +237,9 @@ def main(): # ML 2020/04/08 S # Add vars for ensuring proper normalization and reshaping of sequences parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.") - # parser.add_argument("image_size_h", type=int) - # parser.add_argument("image_size_v", type = int) + parser.add_argument("-height",type=int,default=64) + parser.add_argument("-width",type = int,default=64) + parser.add_argument("-seq_length",type=int,default=20) args = parser.parse_args() current_path = os.getcwd() #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits" @@ -239,7 +247,7 @@ def main(): partition_names = ['train','val', 'test'] #64,64,3 val has issue# for partition_name in partition_names: - read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, N_seq=20, sequences_per_file=2) #Bing: Todo need check the N_seq + read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) #Bing: Todo need check the N_seq #ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5 if __name__ == '__main__':