Skip to content
Snippets Groups Projects
Commit ea4037bf authored by b.gong's avatar b.gong
Browse files

address the hard coding part for Training

parent 08bb6f67
Branches
No related tags found
No related merge requests found
#!/bin/bash -x
python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500
python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20
......@@ -27,8 +27,9 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
dict_message = MessageToDict(tf.train.Example.FromString(example))
feature = dict_message['features']['feature']
image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels'])
self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape
self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
self.image_shape = self.video_shape[1:]
self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
def get_default_hparams_dict(self):
default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict()
......@@ -64,17 +65,23 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
def parser(serialized_example):
seqs = OrderedDict()
keys_to_features = {
# 'width': tf.FixedLenFeature([], tf.int64),
# 'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'sequence_length': tf.FixedLenFeature([], tf.int64),
# 'channels': tf.FixedLenFeature([],tf.int64),
'channels': tf.FixedLenFeature([],tf.int64),
# 'images/encoded': tf.FixedLenFeature([], tf.string)
'images/encoded': tf.VarLenFeature(tf.float32)
}
# for i in range(20):
# keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
print ("Parse features", parsed_features)
seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
# width = tf.sparse_tensor_to_dense(parsed_features["width"])
# height = tf.sparse_tensor_to_dense(parsed_features["height"])
# channels = tf.sparse_tensor_to_dense(parsed_features["channels"])
# sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
images = []
# for i in range(20):
# images.append(parsed_features["images/encoded"].values[i])
......@@ -85,8 +92,8 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
# images = tf.decode_raw(parsed_features["images/encoded"],tf.int32)
# images = seq
images = tf.reshape(seq, [20, 128, 160, 3], name = "reshape_new")
print("IMAGES", images)
print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
seqs["images"] = images
return seqs
filenames = self.filenames
......@@ -154,7 +161,7 @@ def save_tf_record(output_fname, sequences):
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,N_seq,sequences_per_file=128,**kwargs):#Bing: original 128
def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128
# ML 2020/04/08:
# Include vars_in for more flexible data handling (normalization and reshaping)
# and optional keyword argument for kind of normalization
......@@ -189,15 +196,15 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
sequence_iter = 0
sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
X_possible_starts = [i for i in range(len(X_train) - N_seq)]
X_possible_starts = [i for i in range(len(X_train) - seq_length)]
for X_start in X_possible_starts:
print("Interation", sequence_iter)
X_end = X_start + N_seq
X_end = X_start + seq_length
#seq = X_train[X_start:X_end, :, :,:]
seq = X_train[X_start:X_end,:,:]
#print("*****len of seq ***.{}".format(len(seq)))
#seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
seq = list(np.array(seq).reshape((len(seq), 128, 160,nvars)))
seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
if not sequences:
last_start_sequence_iter = sequence_iter
print("reading sequences starting at sequence %d" % sequence_iter)
......@@ -230,8 +237,9 @@ def main():
# ML 2020/04/08 S
# Add vars for ensuring proper normalization and reshaping of sequences
parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
# parser.add_argument("image_size_h", type=int)
# parser.add_argument("image_size_v", type = int)
parser.add_argument("-height",type=int,default=64)
parser.add_argument("-width",type = int,default=64)
parser.add_argument("-seq_length",type=int,default=20)
args = parser.parse_args()
current_path = os.getcwd()
#input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
......@@ -239,7 +247,7 @@ def main():
partition_names = ['train','val', 'test'] #64,64,3 val has issue#
for partition_name in partition_names:
read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, N_seq=20, sequences_per_file=2) #Bing: Todo need check the N_seq
read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) #Bing: Todo need check the N_seq
#ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment