Skip to content
Snippets Groups Projects
Commit f077a719 authored by gong1's avatar gong1
Browse files

address issue for training datasets: number of samples per epoch

parent 0a11fcba
No related branches found
No related tags found
No related merge requests found
...@@ -29,8 +29,8 @@ if [ -z ${VIRTUAL_ENV} ]; then ...@@ -29,8 +29,8 @@ if [ -z ${VIRTUAL_ENV} ]; then
fi fi
# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) # declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500 source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/scarlet_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500 destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/bing_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
# run Preprocessing (step 2 where Tf-records are generated) # run Preprocessing (step 2 where Tf-records are generated)
srun python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/hickle ${destination_dir}/tfrecords -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 srun python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/hickle ${destination_dir}/tfrecords -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20
...@@ -34,9 +34,8 @@ fi ...@@ -34,9 +34,8 @@ fi
# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) # declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500 source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/scarlet_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500 destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/bing_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
# for choosing the model # for choosing the model
model=convLSTM model=convLSTM
model_hparams=../hparams/era5/${model}/model_hparams.json model_hparams=../hparams/era5/${model}/model_hparams.json
......
...@@ -21,4 +21,4 @@ module load mpi4py/3.0.1-Python-3.6.8 ...@@ -21,4 +21,4 @@ module load mpi4py/3.0.1-Python-3.6.8
module load h5py/2.9.0-serial-Python-3.6.8 module load h5py/2.9.0-serial-Python-3.6.8
module load TensorFlow/1.13.1-GPU-Python-3.6.8 module load TensorFlow/1.13.1-GPU-Python-3.6.8
module load cuDNN/7.5.1.10-CUDA-10.1.105 module load cuDNN/7.5.1.10-CUDA-10.1.105
module load netcdf4-python/1.5.0.1-Python-3.6.8
{
"batch_size": 10,
"lr": 0.001,
"max_epochs":2,
"context_frames":10,
"sequence_length":20
}
...@@ -61,7 +61,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset): ...@@ -61,7 +61,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
def filter(self, serialized_example):
    """Predicate used when filtering the tf.data pipeline.

    Accepts every serialized example unconditionally (no filtering is
    performed for this dataset).

    Parameters
    ----------
    serialized_example : a serialized tf.train.Example (unused).

    Returns
    -------
    A scalar boolean tensor that is always True.
    """
    keep_example = tf.convert_to_tensor(True)
    return keep_example
...@@ -325,7 +324,7 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file,vars_in,year,mon ...@@ -325,7 +324,7 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file,vars_in,year,mon
#sequence_lengths_file.close() #sequence_lengths_file.close()
return return
def write_sequence_file(output_dir,seq_length): def write_sequence_file(output_dir,seq_length,sequences_per_file):
partition_names = ["train","val","test"] partition_names = ["train","val","test"]
for partition_name in partition_names: for partition_name in partition_names:
...@@ -333,7 +332,7 @@ def write_sequence_file(output_dir,seq_length): ...@@ -333,7 +332,7 @@ def write_sequence_file(output_dir,seq_length):
tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords")) tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords"))
print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter)) print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter))
sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w') sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
for i in range(tfCounter): for i in range(tfCounter*sequences_per_file):
sequence_lengths_file.write("%d\n" % seq_length) sequence_lengths_file.write("%d\n" % seq_length)
sequence_lengths_file.close() sequence_lengths_file.close()
...@@ -349,6 +348,7 @@ def main(): ...@@ -349,6 +348,7 @@ def main():
parser.add_argument("-height",type=int,default=64) parser.add_argument("-height",type=int,default=64)
parser.add_argument("-width",type = int,default=64) parser.add_argument("-width",type = int,default=64)
parser.add_argument("-seq_length",type=int,default=20) parser.add_argument("-seq_length",type=int,default=20)
parser.add_argument("-sequences_per_file",type=int,default=2)
args = parser.parse_args() args = parser.parse_args()
current_path = os.getcwd() current_path = os.getcwd()
#input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits" #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
...@@ -405,7 +405,7 @@ def main(): ...@@ -405,7 +405,7 @@ def main():
message_counter = message_counter + 1 message_counter = message_counter + 1
print("Message in from slaver",message_in) print("Message in from slaver",message_in)
write_sequence_file(args.output_dir,args.seq_length) write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file)
#write_sequence_file #write_sequence_file
else: else:
...@@ -421,7 +421,7 @@ def main(): ...@@ -421,7 +421,7 @@ def main():
input_file = "X_" + '{0:02}'.format(my_rank) + ".pkl" input_file = "X_" + '{0:02}'.format(my_rank) + ".pkl"
input_dir = os.path.join(args.input_dir,year) input_dir = os.path.join(args.input_dir,year)
input_file = os.path.join(input_dir,input_file) input_file = os.path.join(input_dir,input_file)
#read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir,input_file=input_file,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir,input_file=input_file,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=args.sequences_per_file)
print("Year {} finished",year) print("Year {} finished",year)
message_out = ("Node:",str(my_rank),"finished","","\r\n") message_out = ("Node:",str(my_rank),"finished","","\r\n")
print ("Message out for slaves:",message_out) print ("Message out for slaves:",message_out)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment