diff --git a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py index 51894163489456ab72d330dc40e4af6936be6e36..8d934bc5583e784a89d15fca8d8158e4c9341579 100644 --- a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py +++ b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py @@ -17,17 +17,18 @@ import warnings def main(): parser = argparse.ArgumentParser() - parser.add_argument("-input_dir", type=str) + parser.add_argument("-source_dir", type=str) parser.add_argument("-dest_dir", type=str) parser.add_argument("-sequence_length", type=int, default=20) parser.add_argument("-sequences_per_file", type=int, default=20) args = parser.parse_args() - ins = ERA5Pkl2Tfrecords(input_dir=args.input_dir, + input_dir = args.source_dir + ins = ERA5Pkl2Tfrecords(input_dir=input_dir, + dest_dir=args.dest_dir, sequence_length = args.sequence_length, sequences_per_file=args.sequences_per_file) years, months,years_months = ins.get_years_months() - input_dir_pkl = os.path.join(args.input_dir, "pickle") # ini. MPI comm = MPI.COMM_WORLD my_rank = comm.Get_rank() # rank of the node @@ -36,7 +37,7 @@ def main(): if my_rank == 0: # retrieve final statistics first (not parallelized!) # some preparatory steps - stat_dir_prefix = input_dir_pkl + stat_dir = os.path.dirname(input_dir) varnames = ins.vars_in vars_uni, varsind, nvars = get_unique_vars(varnames) @@ -46,7 +47,7 @@ def main(): print("Start collecting statistics from the whole dataset to be processed...") for year in years: - file_dir = os.path.join(stat_dir_prefix, year) + file_dir = os.path.join(input_dir, year) for month in months: if os.path.isfile(os.path.join(file_dir, "stat_" + '{0:02}'.format(month) + ".json")): # process stat-file: @@ -55,7 +56,7 @@ def main(): warnings.warn("The stat file for year {} month {} does not exist".format(year, month)) # finalize statistics and write to json-file stat_obj.finalize_stat_master(vars_uni) - stat_obj.write_stat_json(input_dir_pkl) + stat_obj.write_stat_json(stat_dir) # organize parallelized partioning real_years_months = [] @@ -90,7 +91,7 @@ def main(): year_rank = "Y_{}_M_{}".format(year, my_rank) if year_rank in real_years_months: # Initilial instance - ins2 = ERA5Pkl2Tfrecords(input_dir=args.input_dir, + ins2 = ERA5Pkl2Tfrecords(input_dir=input_dir, dest_dir=args.dest_dir, sequence_length = args.sequence_length, sequences_per_file=args.sequences_per_file)