diff --git a/video_prediction_tools/data_preprocess/preprocess_data_step2.py b/video_prediction_tools/data_preprocess/preprocess_data_step2.py
index c7df46397291288ff7f6c502158abd0b59889cfc..7cf1931577ce6726c6bce2baa23b8a7934b50a6e 100644
--- a/video_prediction_tools/data_preprocess/preprocess_data_step2.py
+++ b/video_prediction_tools/data_preprocess/preprocess_data_step2.py
@@ -11,6 +11,7 @@ import os
 import glob
 import pickle
 import numpy as np
+import pandas as pd
 import json
 import tensorflow as tf
 from normalization import Norm_data
@@ -32,10 +33,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
                            default: "minmax")
         """
         self.input_dir = input_dir
-        # ML: No hidden path-extensions (rather managed in generate_runscript.py)
-        # self.input_dir_pkl = os.path.join(input_dir,"pickle")
         self.output_dir = dest_dir
-        # if the output_dir is not exist, then create it
+        # if the output_dir does not exist, then create it
         os.makedirs(self.output_dir, exist_ok=True)
         # get metadata,includes the var_in, image height, width etc.
         self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json")
@@ -62,9 +61,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         # search for folder names from pickle folder to get years
         patt = "X_*.pkl"
         for year in self.years:
-            print("pahtL:", os.path.join(self.input_dir, year, patt))
             months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt))
-            print("months_pkl_list", months_pkl_list)
             months_list = [int(m[-6:-4]) for m in months_pkl_list]
             self.months.extend(months_list)
             self.years_months.append(months_list)
@@ -74,13 +71,15 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         """
         Get the corresponding statistics file
         """
-        self.stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json")
-        print("Opening json-file: {0}".format(self.stats_file))
-        if os.path.isfile(self.stats_file):
-            with open(self.stats_file) as js_file:
+        method = ERA5Pkl2Tfrecords.get_stats_file.__name__
+
+        stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json")
+        print("%{0}: Opening json-file: {1}".format(method, stats_file))
+        if os.path.isfile(stats_file):
+            with open(stats_file) as js_file:
                 self.stats = json.load(js_file)
         else:
-            raise FileNotFoundError("Statistic file does not exist")
+            raise FileNotFoundError("%{0}: Could not find statistic file '{1}'".format(method, stats_file))
 
     def get_metadata(self, md_instance):
         """
@@ -90,18 +89,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         height : int, the height of the image
         width  : int, the width of the image
         """
+        method = ERA5Pkl2Tfrecords.get_metadata.__name__
+
         if not isinstance(md_instance, MetaData):
-            raise ValueError("md_instance-argument must be a MetaData class instance")
+            raise ValueError("%{0}: md_instance-argument must be a MetaData class instance".format(method))
         if not hasattr(self, "metadata_fl"):
-            raise ValueError("MetaData class instance passed, but attribute metadata_fl is still missing.")
+            raise ValueError("%{0}: MetaData class instance passed, but attribute metadata_fl is still missing.".format(method))
         try:
             self.height, self.width = md_instance.ny, md_instance.nx
             self.vars_in = md_instance.variables
         except:
-            raise IOError("Could not retrieve all required information from metadata-file '{0}'"
-                          .format(self.metadata_fl))
+            raise IOError("%{0}: Could not retrieve all required information from metadata-file '{1}'"
+                          .format(method, self.metadata_fl))
 
     @staticmethod
     def save_tf_record(output_fname, sequences, t_start_points):
@@ -114,10 +115,12 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         t_start_points : datetime type in the list, the first timestamp for each sequence
                          [seq_len,height,width, channel], the len of t_start_points is the same as sequences
         """
+        method = ERA5Pkl2Tfrecords.save_tf_record.__name__
+
         sequences = np.array(sequences)
         # sanity checks
-        assert sequences.shape[0] == len(t_start_points)
-        assert type(t_start_points[0]) == datetime.datetime
+        assert sequences.shape[0] == len(t_start_points), "%{0}: Length of sequences differs from length of t_start_points.".format(method)
+        assert isinstance(t_start_points[0], datetime.datetime), "%{0}: Elements of t_start_points must be datetime-objects.".format(method)
 
         with tf.python_io.TFRecordWriter(output_fname) as writer:
             for i in range(len(sequences)):
@@ -142,7 +145,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         """
         Get normalization data class
         """
-        print("Make use of default minmax-normalization...")
+        method = ERA5Pkl2Tfrecords.init_norm_class.__name__
+
+        print("%{0}: Make use of default minmax-normalization.".format(method))
         # init normalization-instance
         self.norm_cls = Norm_data(self.vars_in)
         self.nvars = len(self.vars_in)
@@ -160,7 +165,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         Return:
             the normalized sequences
         """
-        assert len(np.array(sequences).shape) == 5
+        method = ERA5Pkl2Tfrecords.normalize_vars_per_seq.__name__
+
+        assert len(np.array(sequences).shape) == 5, "%{0}: Sequence array must be five-dimensional.".format(method)
         # normalization should adpot the selected variables, here we used duplicated channel temperature variables
         sequences = np.array(sequences)
         # normalization
@@ -175,6 +182,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         year  : int, the target year to save to tfrecord
         month : int, the target month to save to tfrecord
         """
+        method = ERA5Pkl2Tfrecords.read_pkl_and_save_tfrecords.__name__
+
         # Define the input_file based on the year and month
         self.input_file_year = os.path.join(self.input_dir, str(year))
         input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month))
@@ -185,11 +194,17 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         t_start_points = []
         sequence_iter = 0
 
-        # try:
-        with open(input_file, "rb") as data_file:
-            X_train = pickle.load(data_file)
-        with open(temp_input_file, "rb") as temp_file:
-            T_train = pickle.load(temp_file)
+        try:
+            with open(input_file, "rb") as data_file:
+                X_train = pickle.load(data_file)
+        except Exception:
+            raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, input_file))
+
+        try:
+            with open(temp_input_file, "rb") as temp_file:
+                T_train = pickle.load(temp_file)
+        except Exception:
+            raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, temp_input_file))
 
         # check to make sure that X_train and T_train have the same length
         assert (len(X_train) == len(T_train))
@@ -199,9 +214,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
             X_end = X_start + self.sequence_length
             seq = X_train[X_start:X_end, ...]
             # recording the start point of the timestamps (already datetime-objects)
-            t_start = T_train[X_start]
-            print("t_start,", t_start)
-            print("type of t_starty", type(t_start))
+            t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0])
             seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars)))
             if not sequences:
                 last_start_sequence_iter = sequence_iter
@@ -219,7 +232,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
                 ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points)
                 t_start_points = []
                 sequences = []
-        print("Finished for input file", input_file)
+        print("%{0}: Finished processing of input file '{1}'".format(method, input_file))
 
         # except FileNotFoundError as fnf_error:
         #     print(fnf_error)
@@ -230,8 +243,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         Function to check if the sequences has been processed.
         If yes, the sequences are skipped, otherwise the sequences are saved to the output file
         """
+        method = ERA5Pkl2Tfrecords.write_seq_to_tfrecord.__name__
+
         if os.path.isfile(output_fname):
-            print(output_fname, 'already exists, skip it')
+            print("%{0}: TFrecord-file {1} already exists. It is therefore skipped.".format(method, output_fname))
         else:
             ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points)
@@ -244,12 +259,25 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file:
             seq_file.write("%d\n" % self.sequences_per_file)
 
-# def num_examples_per_epoch(self):
-#     with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
-#         sequence_lengths = sequence_lengths_file.readlines()
-#     sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
-#     return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
+    @staticmethod
+    def ensure_datetime(date):
+        """
+        Wrapper to return a datetime-object
+        """
+        method = ERA5Pkl2Tfrecords.ensure_datetime.__name__
+
+        if isinstance(date, datetime.datetime):
+            date_new = date
+        else:
+            try:
+                date_new = pd.to_datetime(date)
+                date_new = date_new.to_pydatetime()
+            except Exception as err:
+                print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date)))
+                raise err
+
+        return date_new
 
 def _bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
diff --git a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
index 8d934bc5583e784a89d15fca8d8158e4c9341579..1a16e4e723bda40baafdfada3f9012b51b61016a 100644
--- a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
+++ b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py
@@ -16,6 +16,9 @@ import warnings
 
 
 def main():
+
+    method = "main_preprocess_data_step2"
+
     parser = argparse.ArgumentParser()
     parser.add_argument("-source_dir", type=str)
     parser.add_argument("-dest_dir", type=str)
@@ -25,26 +28,28 @@ def main():
     input_dir = args.source_dir
     ins = ERA5Pkl2Tfrecords(input_dir=input_dir, dest_dir=args.dest_dir,
-                           sequence_length = args.sequence_length,
-                           sequences_per_file=args.sequences_per_file)
+                            sequence_length=args.sequence_length,
+                            sequences_per_file=args.sequences_per_file)
     years, months,years_months = ins.get_years_months()
     # ini. MPI
     comm = MPI.COMM_WORLD
     my_rank = comm.Get_rank()  # rank of the node
-    p = comm.Get_size()  # number of assigned nods
+    p = comm.Get_size()  # number of assigned nodes
+    if p < 2:
+        raise ValueError("%{0}: Preprocessing step 2 must be assigned to at least two tasks.".format(method))
 
     if my_rank == 0:
         # retrieve final statistics first (not parallelized!)
         # some preparatory steps
         stat_dir = os.path.dirname(input_dir)
-        varnames = ins.vars_in 
+        varnames = ins.vars_in
         vars_uni, varsind, nvars = get_unique_vars(varnames)
         stat_obj = Calc_data_stat(nvars)  # init statistic-instance
 
         # loop over whole data set (training, dev and test set) to collect the intermediate statistics
-        print("Start collecting statistics from the whole dataset to be processed...")
+        print("%{0}: Start collecting statistics from the whole dataset to be processed...".format(method))
         for year in years:
             file_dir = os.path.join(input_dir, year)
@@ -53,7 +58,7 @@ def main():
                     # process stat-file:
                     stat_obj.acc_stat_master(file_dir, int(month))  # process monthly statistic-file
                 else:
-                    warnings.warn("The stat file for year {} month {} does not exist".format(year, month))
+                    warnings.warn("%{0}: The statistic file for year {1}, month {2} does not exist".format(method, year, month))
         # finalize statistics and write to json-file
         stat_obj.finalize_stat_master(vars_uni)
         stat_obj.write_stat_json(stat_dir)
@@ -62,9 +67,7 @@ def main():
         real_years_months = []
         for i in range(len(years)):
             year = years[i]
-            print("I am here year:", year)
             for month in years_months[i]:
-                print("I am here month", month)
                 year_month = "Y_{}_M_{}".format(year, month)
                 real_years_months.append(year_month)
 
@@ -74,15 +77,14 @@ def main():
         for nodes in range(1, p):
             comm.send(broadcast_lists, dest=nodes)
 
         message_counter = 1
-        while message_counter <= 12:
+        while message_counter <= p - 1:
             message_in = comm.recv()
             message_counter = message_counter + 1
-            print("Message in from slave: ", message_in)
+            print("%{0}: Message in from worker: {1}".format(method, message_in))
 
     else:
         message_in = comm.recv()
-        print("My rank,", my_rank)
-        print("message_in", message_in)
+        print("%{0}: Message from master to rank {1}: {2}".format(method, my_rank, message_in))
         years = list(message_in[0])
         real_years_months = message_in[1]
@@ -97,11 +99,11 @@ def main():
                                          sequences_per_file=args.sequences_per_file)
                 # create the tfrecords-files
                 ins2.read_pkl_and_save_tfrecords(year=year, month=my_rank)
-                print("Year {} finished", year)
+                print("%{0}: Year {1} finished".format(method, year))
             else:
-                print(year_rank + " is not in the datasplit_dic, will skip the process")
+                print("%{0}: {1} is not in the datasplit_dic, will skip the process".format(method, year_rank))
         message_out = ("Node:", str(my_rank), "finished", "", "\r\n")
-        print("Message out for slaves:", message_out)
+        print("%{0}: Message out from worker: {1}".format(method, message_out))
         comm.send(message_out, dest=0)
 
     MPI.Finalize()
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 30a3286d543ac8ec5f165e9266745e4c2d9732f8..1d3f35dac16c5fde4ba03cad033891fef952e93e 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -139,12 +139,13 @@ class TrainModel(object):
         Setup train and val dataset instance with the corresponding data split configuration.
         Simultaneously, sequence_length is attached to the hyperparameter dictionary.
""" + VideoDataset = datasets.get_dataset_class(self.dataset) self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) # ML/BG 2021-06-15: Is the following needed? - # self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) + self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self): """