Skip to content
Snippets Groups Projects
Commit f077a719 authored by gong1's avatar gong1
Browse files

address issue for training datasets' number of samples per epoch

parent 0a11fcba
Branches
Tags
No related merge requests found
......@@ -29,8 +29,8 @@ if [ -z ${VIRTUAL_ENV} ]; then
fi
# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500
source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/scarlet_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/bing_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
# run Preprocessing (step 2 where Tf-records are generated)
srun python ../video_prediction/datasets/era5_dataset_v2.py ${source_dir}/hickle ${destination_dir}/tfrecords -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20
......@@ -34,9 +34,8 @@ fi
# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py)
source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-160x128-2970N1500W-T2_MSL_gph500
source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/scarlet_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/bing_era5-Y2017_testM01to12-160x128-2970N1500W-T2_MSL_gph500
# for choosing the model
model=convLSTM
model_hparams=../hparams/era5/${model}/model_hparams.json
......
......
......@@ -21,4 +21,4 @@ module load mpi4py/3.0.1-Python-3.6.8
module load h5py/2.9.0-serial-Python-3.6.8
module load TensorFlow/1.13.1-GPU-Python-3.6.8
module load cuDNN/7.5.1.10-CUDA-10.1.105
module load netcdf4-python/1.5.0.1-Python-3.6.8
{
"batch_size": 10,
"lr": 0.001,
"max_epochs":2,
"context_frames":10,
"sequence_length":20
}
......@@ -61,7 +61,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
def filter(self, serialized_example):
    """Decide whether a serialized TFRecord example is kept in the dataset.

    This implementation is a deliberate no-op: it unconditionally returns a
    constant ``True`` tensor, so every record passes and ``serialized_example``
    is ignored.  The constant is wrapped with ``tf.convert_to_tensor`` because
    ``tf.data.Dataset.filter`` predicates must return a boolean tensor, not a
    Python bool.

    NOTE(review): the base class presumably applies this via
    ``dataset.filter(...)`` — confirm against VarLenFeatureVideoDataset, which
    is outside this view.
    """
    return tf.convert_to_tensor(True)
......@@ -325,7 +324,7 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file,vars_in,year,mon
#sequence_lengths_file.close()
return
def write_sequence_file(output_dir,seq_length):
def write_sequence_file(output_dir,seq_length,sequences_per_file):
partition_names = ["train","val","test"]
for partition_name in partition_names:
......@@ -333,7 +332,7 @@ def write_sequence_file(output_dir,seq_length):
tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords"))
print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter))
sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
for i in range(tfCounter):
for i in range(tfCounter*sequences_per_file):
sequence_lengths_file.write("%d\n" % seq_length)
sequence_lengths_file.close()
......@@ -349,6 +348,7 @@ def main():
parser.add_argument("-height",type=int,default=64)
parser.add_argument("-width",type = int,default=64)
parser.add_argument("-seq_length",type=int,default=20)
parser.add_argument("-sequences_per_file",type=int,default=2)
args = parser.parse_args()
current_path = os.getcwd()
#input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
......@@ -405,7 +405,7 @@ def main():
message_counter = message_counter + 1
print("Message in from slaver",message_in)
write_sequence_file(args.output_dir,args.seq_length)
write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file)
#write_sequence_file
else:
......@@ -421,7 +421,7 @@ def main():
input_file = "X_" + '{0:02}'.format(my_rank) + ".pkl"
input_dir = os.path.join(args.input_dir,year)
input_file = os.path.join(input_dir,input_file)
#read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir,input_file=input_file,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2)
read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir,input_file=input_file,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=args.sequences_per_file)
print("Year {} finished",year)
message_out = ("Node:",str(my_rank),"finished","","\r\n")
print ("Message out for slaves:",message_out)
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please sign in or register to comment