diff --git a/video_prediction_tools/data_preprocess/preprocess_data_step2.py b/video_prediction_tools/data_preprocess/preprocess_data_step2.py index 4170b20e58e29036d766010a7ffbe341e311d755..ba6457a625ad5d2d8254817c6b8e9a3e93daa403 100644 --- a/video_prediction_tools/data_preprocess/preprocess_data_step2.py +++ b/video_prediction_tools/data_preprocess/preprocess_data_step2.py @@ -11,6 +11,7 @@ import os import glob import pickle import numpy as np +import pandas as pd import json import tensorflow as tf from normalization import Norm_data @@ -32,10 +33,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): default: "minmax") """ self.input_dir = input_dir - # ML: No hidden path-extensions (rather managed in generate_runscript.py) - # self.input_dir_pkl = os.path.join(input_dir,"pickle") self.output_dir = dest_dir - # if the output_dir is not exist, then create it + # if the output_dir does not exist, then create it os.makedirs(self.output_dir, exist_ok=True) # get metadata,includes the var_in, image height, width etc. 
self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json") @@ -62,9 +61,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): # search for folder names from pickle folder to get years patt = "X_*.pkl" for year in self.years: - print("pahtL:", os.path.join(self.input_dir, year, patt)) months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt)) - print("months_pkl_list", months_pkl_list) months_list = [int(m[-6:-4]) for m in months_pkl_list] self.months.extend(months_list) self.years_months.append(months_list) @@ -74,13 +71,15 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): """ Get the corresponding statistics file """ - self.stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json") - print("Opening json-file: {0}".format(self.stats_file)) - if os.path.isfile(self.stats_file): - with open(self.stats_file) as js_file: + method = ERA5Pkl2Tfrecords.get_stat_file.__name__ + + stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json") + print("Opening json-file: {0}".format(stats_file)) + if os.path.isfile(stats_file): + with open(stats_file) as js_file: self.stats = json.load(js_file) else: - raise FileNotFoundError("Statistic file does not exist") + raise FileNotFoundError("%{0}: Could not find statistic file '{1}'".format(method, stats_file)) def get_metadata(self, md_instance): """ @@ -90,18 +89,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): height : int, the height of the image width : int, the width of the image """ + method = ERA5Pkl2Tfrecords.get_metadata.__name__ + if not isinstance(md_instance, MetaData): - raise ValueError("md_instance-argument must be a MetaData class instance") + raise ValueError("%{0}: md_instance-argument must be a MetaData class instance".format(method)) if not hasattr(self, "metadata_fl"): - raise ValueError("MetaData class instance passed, but attribute metadata_fl is still missing.") + raise ValueError("%{0}: MetaData class instance passed, but attribute metadata_fl is still 
missing.".format(method))
 
         try:
             self.height, self.width = md_instance.ny, md_instance.nx
             self.vars_in = md_instance.variables
         except:
-            raise IOError("Could not retrieve all required information from metadata-file '{0}'"
-                          .format(self.metadata_fl))
+            raise IOError("%{0}: Could not retrieve all required information from metadata-file '{1}'"
+                          .format(method, self.metadata_fl))
 
     @staticmethod
     def save_tf_record(output_fname, sequences, t_start_points):
         """
         Save the squences, and the corresponding timestamp start point to tfrecords
@@ -114,12 +115,12 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         t_start_points : datetime type in the list, the first timestamp for each sequence
                          [seq_len,height,width, channel], the len of t_start_points is the same as sequences
         """
+        method = ERA5Pkl2Tfrecords.save_tf_record.__name__
+
         sequences = np.array(sequences)
         # sanity checks
-        print(t_start_points[0])
-        print(type(t_start_points[0]))
-        assert sequences.shape[0] == len(t_start_points)
-        assert type(t_start_points) == datetime.datetime, "What's that: {0} (type {1})".format(t_start_points[0], type(t_start_points[0]))
+        assert sequences.shape[0] == len(t_start_points), "%{0}: Lengths of sequence differs from length of t_start_points.".format(method)
+        assert type(t_start_points[0]) == datetime.datetime, "%{0}: Elements of t_start_points must be datetime-objects.".format(method)
 
         with tf.python_io.TFRecordWriter(output_fname) as writer:
             for i in range(len(sequences)):
@@ -144,7 +145,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         """
         Get normalization data class
         """
-        print("Make use of default minmax-normalization...")
+        method = ERA5Pkl2Tfrecords.init_norm_class.__name__
+
+        print("%{0}: Make use of default minmax-normalization.".format(method))
         # init normalization-instance
         self.norm_cls = Norm_data(self.vars_in)
         self.nvars = len(self.vars_in)
@@ -162,7 +165,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         Return:
             the normalized sequences
         """
-        assert len(np.array(sequences).shape) == 5
+        method = ERA5Pkl2Tfrecords.normalize_vars_per_seq.__name__
+
+        assert 
len(np.array(sequences).shape) == 5, "%{0}: Sequence array must be 5-dimensional.".format(method)
         # normalization should adpot the selected variables, here we used duplicated channel temperature variables
         sequences = np.array(sequences)
         # normalization
@@ -177,6 +182,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         year : int, the target year to save to tfrecord
         month : int, the target month to save to tfrecord
         """
+        method = ERA5Pkl2Tfrecords.read_pkl_and_save_tfrecords.__name__
+
         # Define the input_file based on the year and month
         self.input_file_year = os.path.join(self.input_dir, str(year))
         input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month))
@@ -187,11 +194,17 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
         t_start_points = []
         sequence_iter = 0
-        # try:
-        with open(input_file, "rb") as data_file:
-            X_train = pickle.load(data_file)
-        with open(temp_input_file, "rb") as temp_file:
-            T_train = pickle.load(temp_file)
+        try:
+            with open(input_file, "rb") as data_file:
+                X_train = pickle.load(data_file)
+        except:
+            raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, input_file))
+
+        try:
+            with open(temp_input_file, "rb") as temp_file:
+                T_train = pickle.load(temp_file)
+        except:
+            raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, temp_input_file))
 
         # check to make sure that X_train and T_train have the same length
         assert (len(X_train) == len(T_train))
@@ -202,8 +215,6 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
             seq = X_train[X_start:X_end, ...] 
# recording the start point of the timestamps (already datetime-objects) t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0]) - print("t_start,", t_start) - print("type of t_starty", type(t_start)) seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars))) if not sequences: last_start_sequence_iter = sequence_iter @@ -221,7 +232,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points) t_start_points = [] sequences = [] - print("Finished for input file", input_file) + print("%{0}: Finished processing of input file '{1}'".format(method, input_file)) # except FileNotFoundError as fnf_error: # print(fnf_error) @@ -232,8 +243,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): Function to check if the sequences has been processed. If yes, the sequences are skipped, otherwise the sequences are saved to the output file """ + method = ERA5Pkl2Tfrecords.write_seq_to_tfrecord.__name__ + if os.path.isfile(output_fname): - print(output_fname, 'already exists, skip it') + print("%{0}: TFrecord-file {1} already exists. 
It is therefore skipped.".format(method, output_fname)) else: ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points) @@ -252,24 +265,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): """ Wrapper to return a datetime-object """ + method = ERA5Pkl2Tfrecords.ensure_datetime.__name__ + fmt = "%Y%m%d %H:%M" if isinstance(date, datetime.datetime): date_new = date else: try: date_new=pd.to_datetime(date) - date_new=datetime.datetime(date_new.strptime(fmt), fmt) + date_new=date_new.to_pydatetime() except Exception as err: print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date))) raise err return date_new -# def num_examples_per_epoch(self): -# with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: -# sequence_lengths = sequence_lengths_file.readlines() -# sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] -# return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) - def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))