Skip to content
Snippets Groups Projects
Commit d81e3f2f authored by Michael Langguth
Browse files

Relocating statistic-json and several fixes.

parent b26da757
Branches
Tags
No related merge requests found
Pipeline #59740 passed
......@@ -17,17 +17,18 @@ import warnings
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-input_dir", type=str)
parser.add_argument("-source_dir", type=str)
parser.add_argument("-dest_dir", type=str)
parser.add_argument("-sequence_length", type=int, default=20)
parser.add_argument("-sequences_per_file", type=int, default=20)
args = parser.parse_args()
ins = ERA5Pkl2Tfrecords(input_dir=args.input_dir,
input_dir = args.source_dir
ins = ERA5Pkl2Tfrecords(input_dir=input_dir,
dest_dir=args.dest_dir,
sequence_length = args.sequence_length,
sequences_per_file=args.sequences_per_file)
years, months,years_months = ins.get_years_months()
input_dir_pkl = os.path.join(args.input_dir, "pickle")
# ini. MPI
comm = MPI.COMM_WORLD
my_rank = comm.Get_rank() # rank of the node
......@@ -36,7 +37,7 @@ def main():
if my_rank == 0:
# retrieve final statistics first (not parallelized!)
# some preparatory steps
stat_dir_prefix = input_dir_pkl
stat_dir = os.path.dirname(input_dir)
varnames = ins.vars_in
vars_uni, varsind, nvars = get_unique_vars(varnames)
......@@ -46,7 +47,7 @@ def main():
print("Start collecting statistics from the whole dataset to be processed...")
for year in years:
file_dir = os.path.join(stat_dir_prefix, year)
file_dir = os.path.join(input_dir, year)
for month in months:
if os.path.isfile(os.path.join(file_dir, "stat_" + '{0:02}'.format(month) + ".json")):
# process stat-file:
......@@ -55,7 +56,7 @@ def main():
warnings.warn("The stat file for year {} month {} does not exist".format(year, month))
# finalize statistics and write to json-file
stat_obj.finalize_stat_master(vars_uni)
stat_obj.write_stat_json(input_dir_pkl)
stat_obj.write_stat_json(stat_dir)
# organize parallelized partitioning
real_years_months = []
......@@ -90,7 +91,7 @@ def main():
year_rank = "Y_{}_M_{}".format(year, my_rank)
if year_rank in real_years_months:
# Initial instance
ins2 = ERA5Pkl2Tfrecords(input_dir=args.input_dir,
ins2 = ERA5Pkl2Tfrecords(input_dir=input_dir,
dest_dir=args.dest_dir,
sequence_length = args.sequence_length,
sequences_per_file=args.sequences_per_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment