Skip to content
Snippets Groups Projects
Commit e8309fe5 authored by stadtler1's avatar stadtler1
Browse files

Merge branch 'bing_issue#010_remove_hickle_split_data' of...

Merge branch 'bing_issue#010_remove_hickle_split_data' of https://gitlab.version.fz-juelich.de/toar/ambs into bing_issue#010_remove_hickle_split_data
parents 1033de25 adcdd8c9
Branches
Tags
No related merge requests found
Pipeline #42517 failed
......@@ -357,6 +357,7 @@ def main():
sess.graph.as_default()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
model.restore(sess, args.checkpoint)
#model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data
sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data()
......
......@@ -62,7 +62,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
def filter(self, serialized_example):
return tf.convert_to_tensor(True)
......@@ -326,7 +325,7 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file, temp_input_file
#sequence_lengths_file.close()
return
def write_sequence_file(output_dir,seq_length):
def write_sequence_file(output_dir,seq_length,sequences_per_file):
partition_names = ["train","val","test"]
for partition_name in partition_names:
......@@ -334,7 +333,7 @@ def write_sequence_file(output_dir,seq_length):
tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords"))
print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter))
sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
for i in range(tfCounter):
for i in range(tfCounter*sequences_per_file):
sequence_lengths_file.write("%d\n" % seq_length)
sequence_lengths_file.close()
......@@ -350,6 +349,7 @@ def main():
parser.add_argument("-height",type=int,default=64)
parser.add_argument("-width",type = int,default=64)
parser.add_argument("-seq_length",type=int,default=20)
parser.add_argument("-sequences_per_file",type=int,default=2)
args = parser.parse_args()
current_path = os.getcwd()
#input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
......@@ -366,13 +366,13 @@ def main():
# "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
# "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
# "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
"2017":[1,2,3,4,5,6,7,8,9,10]
"2017_test":[1,2,3,4,5,6,7,8,9,10]
},
"val":
{"2017":[11]
{"2017_test":[11]
},
"test":
{"2017":[12]
{"2017_test":[12]
}
}
......@@ -424,7 +424,8 @@ def main():
message_counter = message_counter + 1
print("Message in from slaver",message_in)
write_sequence_file(args.output_dir,args.seq_length)
write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file)
#write_sequence_file
else:
message_in = comm.recv()
......@@ -449,7 +450,7 @@ def main():
read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir, \
input_file=input_file,temp_input_file=temp_file,vars_in=args.variables, \
partition_name=partition_name,seq_length=args.seq_length, \
height=args.height,width=args.width,sequences_per_file=20)
height=args.height,width=args.width,sequences_per_file=args.sequences_per_file)
print("Year {} finished",year)
message_out = ("Node:",str(my_rank),"finished","","\r\n")
......
......@@ -41,7 +41,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
lr: learning rate. if decay steps is non-zero, this is the
learning rate for steps <= decay_step.
max_steps: number of training steps.
context_frames: the number of ground-truth frames to pass in at
context_frames: the number of ground-truth frames to pass :qin at
start. Must be specified during instantiation.
sequence_length: the number of frames in the video sequence,
including the context frames, so this model predicts
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment