Skip to content
Snippets Groups Projects
Commit 5c0f3e79 authored by Michael Langguth's avatar Michael Langguth
Browse files

Corrected handling of t_start_points in preprocess_data_step2.py and source-code style changes.

parent ff258b1d
Branches
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ import os ...@@ -11,6 +11,7 @@ import os
import glob import glob
import pickle import pickle
import numpy as np import numpy as np
import pandas as pd
import json import json
import tensorflow as tf import tensorflow as tf
from normalization import Norm_data from normalization import Norm_data
...@@ -32,10 +33,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -32,10 +33,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
default: "minmax") default: "minmax")
""" """
self.input_dir = input_dir self.input_dir = input_dir
# ML: No hidden path-extensions (rather managed in generate_runscript.py)
# self.input_dir_pkl = os.path.join(input_dir,"pickle")
self.output_dir = dest_dir self.output_dir = dest_dir
# if the output_dir is not exist, then create it # if the output_dir does not exist, then create it
os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.output_dir, exist_ok=True)
# get metadata,includes the var_in, image height, width etc. # get metadata,includes the var_in, image height, width etc.
self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json") self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json")
...@@ -62,9 +61,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -62,9 +61,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
# search for folder names from pickle folder to get years # search for folder names from pickle folder to get years
patt = "X_*.pkl" patt = "X_*.pkl"
for year in self.years: for year in self.years:
print("pahtL:", os.path.join(self.input_dir, year, patt))
months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt)) months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt))
print("months_pkl_list", months_pkl_list)
months_list = [int(m[-6:-4]) for m in months_pkl_list] months_list = [int(m[-6:-4]) for m in months_pkl_list]
self.months.extend(months_list) self.months.extend(months_list)
self.years_months.append(months_list) self.years_months.append(months_list)
...@@ -74,13 +71,15 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -74,13 +71,15 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
""" """
Get the corresponding statistics file Get the corresponding statistics file
""" """
self.stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json") method = ERA5Pkl2Tfrecords.get_stat_file.__name__
print("Opening json-file: {0}".format(self.stats_file))
if os.path.isfile(self.stats_file): stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json")
with open(self.stats_file) as js_file: print("Opening json-file: {0}".format(stats_file))
if os.path.isfile(stats_file):
with open(stats_file) as js_file:
self.stats = json.load(js_file) self.stats = json.load(js_file)
else: else:
raise FileNotFoundError("Statistic file does not exist") raise FileNotFoundError("%{0}: Could not find statistic file '{1}'".format(method, stats_file))
def get_metadata(self, md_instance): def get_metadata(self, md_instance):
""" """
...@@ -90,18 +89,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -90,18 +89,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
height : int, the height of the image height : int, the height of the image
width : int, the width of the image width : int, the width of the image
""" """
method = ERA5Pkl2Tfrecords.get_metadata.__name__
if not isinstance(md_instance, MetaData): if not isinstance(md_instance, MetaData):
raise ValueError("md_instance-argument must be a MetaData class instance") raise ValueError("%{0}: md_instance-argument must be a MetaData class instance".format(method))
if not hasattr(self, "metadata_fl"): if not hasattr(self, "metadata_fl"):
raise ValueError("MetaData class instance passed, but attribute metadata_fl is still missing.") raise ValueError("%{0}: MetaData class instance passed, but attribute metadata_fl is still missing.".format(method))
try: try:
self.height, self.width = md_instance.ny, md_instance.nx self.height, self.width = md_instance.ny, md_instance.nx
self.vars_in = md_instance.variables self.vars_in = md_instance.variables
except: except:
raise IOError("Could not retrieve all required information from metadata-file '{0}'" raise IOError("%{0}: Could not retrieve all required information from metadata-file '{0}'"
.format(self.metadata_fl)) .format(method, self.metadata_fl))
@staticmethod @staticmethod
def save_tf_record(output_fname, sequences, t_start_points): def save_tf_record(output_fname, sequences, t_start_points):
...@@ -114,12 +115,12 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -114,12 +115,12 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
t_start_points : datetime type in the list, the first timestamp for each sequence t_start_points : datetime type in the list, the first timestamp for each sequence
[seq_len,height,width, channel], the len of t_start_points is the same as sequences [seq_len,height,width, channel], the len of t_start_points is the same as sequences
""" """
method = ERA5Pkl2Tfrecords.save_tf_record.__name__
sequences = np.array(sequences) sequences = np.array(sequences)
# sanity checks # sanity checks
print(t_start_points[0]) assert sequences.shape[0] == len(t_start_points), "%{0}: Lengths of sequence differs from length of t_start_points.".format(method)
print(type(t_start_points[0])) assert type(t_start_points[0]) == datetime.datetime, "%{0}: Elements of t_start_points must be datetime-objects.".format(method)
assert sequences.shape[0] == len(t_start_points)
assert type(t_start_points) == datetime.datetime, "What's that: {0} (type {1})".format(t_start_points[0], type(t_start_points[0]))
with tf.python_io.TFRecordWriter(output_fname) as writer: with tf.python_io.TFRecordWriter(output_fname) as writer:
for i in range(len(sequences)): for i in range(len(sequences)):
...@@ -144,7 +145,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -144,7 +145,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
""" """
Get normalization data class Get normalization data class
""" """
print("Make use of default minmax-normalization...") method = ERA5Pkl2Tfrecords.init_norm_class.__name__
print("%{0}: Make use of default minmax-normalization.".format(method))
# init normalization-instance # init normalization-instance
self.norm_cls = Norm_data(self.vars_in) self.norm_cls = Norm_data(self.vars_in)
self.nvars = len(self.vars_in) self.nvars = len(self.vars_in)
...@@ -162,7 +165,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -162,7 +165,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
Return: Return:
the normalized sequences the normalized sequences
""" """
assert len(np.array(sequences).shape) == 5 method = ERA5Pkl2Tfrecords.normalize_vars_per_seq.__name__
assert len(np.array(sequences).shape) == 5, "%{0}: Length of sequence array must be 5.".format(method)
# normalization should adpot the selected variables, here we used duplicated channel temperature variables # normalization should adpot the selected variables, here we used duplicated channel temperature variables
sequences = np.array(sequences) sequences = np.array(sequences)
# normalization # normalization
...@@ -177,6 +182,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -177,6 +182,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
year : int, the target year to save to tfrecord year : int, the target year to save to tfrecord
month : int, the target month to save to tfrecord month : int, the target month to save to tfrecord
""" """
method = ERA5Pkl2Tfrecords.read_pkl_and_save_tfrecords.__name__
# Define the input_file based on the year and month # Define the input_file based on the year and month
self.input_file_year = os.path.join(self.input_dir, str(year)) self.input_file_year = os.path.join(self.input_dir, str(year))
input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month)) input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month))
...@@ -187,11 +194,17 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -187,11 +194,17 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
t_start_points = [] t_start_points = []
sequence_iter = 0 sequence_iter = 0
# try: try:
with open(input_file, "rb") as data_file: with open(input_file, "rb") as data_file:
X_train = pickle.load(data_file) X_train = pickle.load(data_file)
except:
raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, input_file))
try:
with open(temp_input_file, "rb") as temp_file: with open(temp_input_file, "rb") as temp_file:
T_train = pickle.load(temp_file) T_train = pickle.load(temp_file)
except:
raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, temp_input_file))
# check to make sure that X_train and T_train have the same length # check to make sure that X_train and T_train have the same length
assert (len(X_train) == len(T_train)) assert (len(X_train) == len(T_train))
...@@ -202,8 +215,6 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -202,8 +215,6 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
seq = X_train[X_start:X_end, ...] seq = X_train[X_start:X_end, ...]
# recording the start point of the timestamps (already datetime-objects) # recording the start point of the timestamps (already datetime-objects)
t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0]) t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0])
print("t_start,", t_start)
print("type of t_starty", type(t_start))
seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars))) seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars)))
if not sequences: if not sequences:
last_start_sequence_iter = sequence_iter last_start_sequence_iter = sequence_iter
...@@ -221,7 +232,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -221,7 +232,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points) ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points)
t_start_points = [] t_start_points = []
sequences = [] sequences = []
print("Finished for input file", input_file) print("%{0}: Finished processing of input file '{1}'".format(method, input_file))
# except FileNotFoundError as fnf_error: # except FileNotFoundError as fnf_error:
# print(fnf_error) # print(fnf_error)
...@@ -232,8 +243,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -232,8 +243,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
Function to check if the sequences has been processed. Function to check if the sequences has been processed.
If yes, the sequences are skipped, otherwise the sequences are saved to the output file If yes, the sequences are skipped, otherwise the sequences are saved to the output file
""" """
method = ERA5Pkl2Tfrecords.write_seq_to_tfrecord.__name__
if os.path.isfile(output_fname): if os.path.isfile(output_fname):
print(output_fname, 'already exists, skip it') print("%{0}: TFrecord-file {1} already exists. It is therefore skipped.".format(method, output_fname))
else: else:
ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points) ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points)
...@@ -252,24 +265,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): ...@@ -252,24 +265,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
""" """
Wrapper to return a datetime-object Wrapper to return a datetime-object
""" """
method = ERA5Pkl2Tfrecords.ensure_datetime.__name__
fmt = "%Y%m%d %H:%M" fmt = "%Y%m%d %H:%M"
if isinstance(date, datetime.datetime): if isinstance(date, datetime.datetime):
date_new = date date_new = date
else: else:
try: try:
date_new=pd.to_datetime(date) date_new=pd.to_datetime(date)
date_new=datetime.datetime(date_new.strptime(fmt), fmt) date_new=date_new.to_pydatetime()
except Exception as err: except Exception as err:
print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date))) print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date)))
raise err raise err
return date_new return date_new
# def num_examples_per_epoch(self):
# with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
# sequence_lengths = sequence_lengths_file.readlines()
# sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
# return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
def _bytes_feature(value): def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment