Skip to content
Snippets Groups Projects
Commit 5c0f3e79 authored by Michael Langguth's avatar Michael Langguth
Browse files

Corrected handling of t_start_points in preprocess_data_step2.py and source-code style changes.

parent ff258b1d
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ import os
import glob
import pickle
import numpy as np
import pandas as pd
import json
import tensorflow as tf
from normalization import Norm_data
......@@ -32,10 +33,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
default: "minmax")
"""
self.input_dir = input_dir
# ML: No hidden path-extensions (rather managed in generate_runscript.py)
# self.input_dir_pkl = os.path.join(input_dir,"pickle")
self.output_dir = dest_dir
# if the output_dir is not exist, then create it
# if the output_dir does not exist, then create it
os.makedirs(self.output_dir, exist_ok=True)
# get metadata,includes the var_in, image height, width etc.
self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json")
......@@ -62,9 +61,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
# search for folder names from pickle folder to get years
patt = "X_*.pkl"
for year in self.years:
print("pahtL:", os.path.join(self.input_dir, year, patt))
months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt))
print("months_pkl_list", months_pkl_list)
months_list = [int(m[-6:-4]) for m in months_pkl_list]
self.months.extend(months_list)
self.years_months.append(months_list)
......@@ -74,13 +71,15 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
"""
Get the corresponding statistics file
"""
self.stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json")
print("Opening json-file: {0}".format(self.stats_file))
if os.path.isfile(self.stats_file):
with open(self.stats_file) as js_file:
method = ERA5Pkl2Tfrecords.get_stat_file.__name__
stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json")
print("Opening json-file: {0}".format(stats_file))
if os.path.isfile(stats_file):
with open(stats_file) as js_file:
self.stats = json.load(js_file)
else:
raise FileNotFoundError("Statistic file does not exist")
raise FileNotFoundError("%{0}: Could not find statistic file '{1}'".format(method, stats_file))
def get_metadata(self, md_instance):
"""
......@@ -90,18 +89,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
height : int, the height of the image
width : int, the width of the image
"""
method = ERA5Pkl2Tfrecords.get_metadata.__name__
if not isinstance(md_instance, MetaData):
raise ValueError("md_instance-argument must be a MetaData class instance")
raise ValueError("%{0}: md_instance-argument must be a MetaData class instance".format(method))
if not hasattr(self, "metadata_fl"):
raise ValueError("MetaData class instance passed, but attribute metadata_fl is still missing.")
raise ValueError("%{0}: MetaData class instance passed, but attribute metadata_fl is still missing.".format(method))
try:
self.height, self.width = md_instance.ny, md_instance.nx
self.vars_in = md_instance.variables
except:
raise IOError("Could not retrieve all required information from metadata-file '{0}'"
.format(self.metadata_fl))
raise IOError("%{0}: Could not retrieve all required information from metadata-file '{0}'"
.format(method, self.metadata_fl))
@staticmethod
def save_tf_record(output_fname, sequences, t_start_points):
......@@ -114,12 +115,12 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
t_start_points : datetime type in the list, the first timestamp for each sequence
[seq_len,height,width, channel], the len of t_start_points is the same as sequences
"""
method = ERA5Pkl2Tfrecords.save_tf_record.__name__
sequences = np.array(sequences)
# sanity checks
print(t_start_points[0])
print(type(t_start_points[0]))
assert sequences.shape[0] == len(t_start_points)
assert type(t_start_points) == datetime.datetime, "What's that: {0} (type {1})".format(t_start_points[0], type(t_start_points[0]))
assert sequences.shape[0] == len(t_start_points), "%{0}: Lengths of sequence differs from length of t_start_points.".format(method)
assert type(t_start_points[0]) == datetime.datetime, "%{0}: Elements of t_start_points must be datetime-objects.".format(method)
with tf.python_io.TFRecordWriter(output_fname) as writer:
for i in range(len(sequences)):
......@@ -144,7 +145,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
"""
Get normalization data class
"""
print("Make use of default minmax-normalization...")
method = ERA5Pkl2Tfrecords.init_norm_class.__name__
print("%{0}: Make use of default minmax-normalization.".format(method))
# init normalization-instance
self.norm_cls = Norm_data(self.vars_in)
self.nvars = len(self.vars_in)
......@@ -162,7 +165,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
Return:
the normalized sequences
"""
assert len(np.array(sequences).shape) == 5
method = ERA5Pkl2Tfrecords.normalize_vars_per_seq.__name__
assert len(np.array(sequences).shape) == 5, "%{0}: Length of sequence array must be 5.".format(method)
# normalization should adpot the selected variables, here we used duplicated channel temperature variables
sequences = np.array(sequences)
# normalization
......@@ -177,6 +182,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
year : int, the target year to save to tfrecord
month : int, the target month to save to tfrecord
"""
method = ERA5Pkl2Tfrecords.read_pkl_and_save_tfrecords.__name__
# Define the input_file based on the year and month
self.input_file_year = os.path.join(self.input_dir, str(year))
input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month))
......@@ -187,11 +194,17 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
t_start_points = []
sequence_iter = 0
# try:
try:
with open(input_file, "rb") as data_file:
X_train = pickle.load(data_file)
except:
raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, input_file))
try:
with open(temp_input_file, "rb") as temp_file:
T_train = pickle.load(temp_file)
except:
raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, temp_input_file))
# check to make sure that X_train and T_train have the same length
assert (len(X_train) == len(T_train))
......@@ -202,8 +215,6 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
seq = X_train[X_start:X_end, ...]
# recording the start point of the timestamps (already datetime-objects)
t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0])
print("t_start,", t_start)
print("type of t_starty", type(t_start))
seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars)))
if not sequences:
last_start_sequence_iter = sequence_iter
......@@ -221,7 +232,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points)
t_start_points = []
sequences = []
print("Finished for input file", input_file)
print("%{0}: Finished processing of input file '{1}'".format(method, input_file))
# except FileNotFoundError as fnf_error:
# print(fnf_error)
......@@ -232,8 +243,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
Function to check if the sequences has been processed.
If yes, the sequences are skipped, otherwise the sequences are saved to the output file
"""
method = ERA5Pkl2Tfrecords.write_seq_to_tfrecord.__name__
if os.path.isfile(output_fname):
print(output_fname, 'already exists, skip it')
print("%{0}: TFrecord-file {1} already exists. It is therefore skipped.".format(method, output_fname))
else:
ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points)
......@@ -252,24 +265,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
"""
Wrapper to return a datetime-object
"""
method = ERA5Pkl2Tfrecords.ensure_datetime.__name__
fmt = "%Y%m%d %H:%M"
if isinstance(date, datetime.datetime):
date_new = date
else:
try:
date_new=pd.to_datetime(date)
date_new=datetime.datetime(date_new.strptime(fmt), fmt)
date_new=date_new.to_pydatetime()
except Exception as err:
print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date)))
raise err
return date_new
# def num_examples_per_epoch(self):
# with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
# sequence_lengths = sequence_lengths_file.readlines()
# sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
# return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
def _bytes_feature(value):
    """Wrap a raw byte string in a tf.train.Feature holding a single-element BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment