Commit ea4037bf authored by b.gong
Address the hard-coded parameters for training

parent 08bb6f67
#!/bin/bash -x
-python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500
+python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20
@@ -27,8 +27,9 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
dict_message = MessageToDict(tf.train.Example.FromString(example))
feature = dict_message['features']['feature']
-image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels'])
-self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape
+self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
+self.image_shape = self.video_shape[1:]
+self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
def get_default_hparams_dict(self):
default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict()
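Note (not part of the commit): reading self.video_shape from the first record only works if the writer stores the shape metadata as per-example int64 features. A minimal sketch of such a writer, with illustrative names, might look like this:

import tensorflow as tf  # TF 1.x API, matching the rest of this file

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def make_example(flat_sequence, seq_length, height, width, channels):
    # flat_sequence: float32 values of length seq_length * height * width * channels
    features = tf.train.Features(feature={
        'sequence_length': _int64_feature(seq_length),
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'channels': _int64_feature(channels),
        'images/encoded': tf.train.Feature(
            float_list=tf.train.FloatList(value=flat_sequence)),
    })
    return tf.train.Example(features=features)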
@@ -64,17 +65,23 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
def parser(serialized_example):
seqs = OrderedDict()
keys_to_features = {
-# 'width': tf.FixedLenFeature([], tf.int64),
-# 'height': tf.FixedLenFeature([], tf.int64),
+'width': tf.FixedLenFeature([], tf.int64),
+'height': tf.FixedLenFeature([], tf.int64),
'sequence_length': tf.FixedLenFeature([], tf.int64),
-# 'channels': tf.FixedLenFeature([],tf.int64),
+'channels': tf.FixedLenFeature([],tf.int64),
# 'images/encoded': tf.FixedLenFeature([], tf.string)
'images/encoded': tf.VarLenFeature(tf.float32)
}
# for i in range(20):
# keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
print ("Parse features", parsed_features)
seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
# width = tf.sparse_tensor_to_dense(parsed_features["width"])
# height = tf.sparse_tensor_to_dense(parsed_features["height"])
# channels = tf.sparse_tensor_to_dense(parsed_features["channels"])
# sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
images = [] images = []
# for i in range(20): # for i in range(20):
# images.append(parsed_features["images/encoded"].values[i]) # images.append(parsed_features["images/encoded"].values[i])
@@ -85,8 +92,8 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
# images = tf.decode_raw(parsed_features["images/encoded"],tf.int32)
# images = seq
-images = tf.reshape(seq, [20, 128, 160, 3], name = "reshape_new")
-print("IMAGES", images)
+print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
+images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
seqs["images"] = images
return seqs
filenames = self.filenames
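Note (not part of the commit): a condensed sketch of how a parser like the one above is typically wired into a tf.data input pipeline. make_parser, the placeholder file name, and the example shape (20, 128, 160, 3) are illustrative; in the class the shape comes from self.video_shape rather than being hard coded:

import tensorflow as tf  # TF 1.x API, matching the rest of this file

def make_parser(video_shape):
    # video_shape = (sequence_length, height, width, channels)
    def _parse(serialized_example):
        parsed = tf.parse_single_example(
            serialized_example,
            {'images/encoded': tf.VarLenFeature(tf.float32)})
        seq = tf.sparse_tensor_to_dense(parsed['images/encoded'])
        return {'images': tf.reshape(seq, list(video_shape))}
    return _parse

filenames = ["/path/to/sequence_0_to_1.tfrecords"]   # placeholder
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(make_parser((20, 128, 160, 3)), num_parallel_calls=4)
dataset = dataset.batch(4).prefetch(1)
images = dataset.make_one_shot_iterator().get_next()['images']
# images: float32 tensor of shape [4, 20, 128, 160, 3]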
@@ -154,7 +161,7 @@ def save_tf_record(output_fname, sequences):
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
-def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,N_seq,sequences_per_file=128,**kwargs):#Bing: original 128
+def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128
# ML 2020/04/08:
# Include vars_in for more flexible data handling (normalization and reshaping)
# and optional keyword argument for kind of normalization
@@ -189,15 +196,15 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
sequence_iter = 0
sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
-X_possible_starts = [i for i in range(len(X_train) - N_seq)]
+X_possible_starts = [i for i in range(len(X_train) - seq_length)]
for X_start in X_possible_starts:
print("Interation", sequence_iter)
-X_end = X_start + N_seq
+X_end = X_start + seq_length
#seq = X_train[X_start:X_end, :, :,:]
seq = X_train[X_start:X_end,:,:]
#print("*****len of seq ***.{}".format(len(seq)))
#seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
-seq = list(np.array(seq).reshape((len(seq), 128, 160,nvars)))
+seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
if not sequences:
last_start_sequence_iter = sequence_iter
print("reading sequences starting at sequence %d" % sequence_iter)
@@ -230,8 +237,9 @@ def main():
# ML 2020/04/08 S
# Add vars for ensuring proper normalization and reshaping of sequences
parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
-# parser.add_argument("image_size_h", type=int)
-# parser.add_argument("image_size_v", type = int)
+parser.add_argument("-height",type=int,default=64)
+parser.add_argument("-width",type = int,default=64)
+parser.add_argument("-seq_length",type=int,default=20)
args = parser.parse_args()
current_path = os.getcwd()
#input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
@@ -239,7 +247,7 @@ def main():
partition_names = ['train','val', 'test'] #64,64,3 val has issue#
for partition_name in partition_names:
-read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, N_seq=20, sequences_per_file=2) #Bing: Todo need check the N_seq
+read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) #Bing: Todo need check the N_seq
#ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5
if __name__ == '__main__':
...
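Note (not part of the commit): a direct call equivalent to what main() now does for one partition with the values the run script passes on the command line; the import path is assumed from the run script and the paths are placeholders:

# Assumed import path, based on the run script above:
from video_prediction.datasets.era5_dataset_v2 import read_frames_and_save_tf_records

read_frames_and_save_tf_records(
    output_dir="/path/to/tfrecords",      # placeholder paths
    input_dir="/path/to/hickle/splits",
    vars_in=["T2", "MSL", "gph500"],
    partition_name="train",
    seq_length=20, height=128, width=160,
    sequences_per_file=2)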