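"""TFRecord-based ERA5 dataset for video prediction.

This module provides the ERA5Dataset_v2 dataset class, a small normalization
helper (norm_data) and a command-line script that converts preprocessed ERA5
sequences (hickle files) into TFRecord files.
"""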
import argparse
import glob
import itertools
import os
import pickle
import random
import re
import hickle as hkl
import numpy as np
import json
import tensorflow as tf
from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
# ML 2020/04/14: hack for getting functions of process_netCDF_v2:
from os import path
import sys
sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
from DataPreprocess.process_netCDF_v2 import get_unique_vars
#from base_dataset import VarLenFeatureVideoDataset
from collections import OrderedDict
from tensorflow.contrib.training import HParams
class ERA5Dataset_v2(VarLenFeatureVideoDataset):
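    """Video dataset backed by ERA5 TFRecord files.

    The video/image shape is inferred from the first TFRecord file by reading its
    'sequence_length', 'height', 'width' and 'channels' features.
    """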
def __init__(self, *args, **kwargs):
super(ERA5Dataset_v2, self).__init__(*args, **kwargs)
from google.protobuf.json_format import MessageToDict
example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
dict_message = MessageToDict(tf.train.Example.FromString(example))
feature = dict_message['features']['feature']
self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
self.image_shape = self.video_shape[1:]
self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
def get_default_hparams_dict(self):
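        """Return the base-class default hparams extended with the ERA5-specific defaults below."""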
default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict()
hparams = dict(
            context_frames=10,   # Bing: TODO original is 10
            sequence_length=20,  # Bing: TODO original is 20
long_sequence_length=20,
force_time_shift=True,
shuffle_on_val=True,
use_state=False,
)
return dict(itertools.chain(default_hparams.items(), hparams.items()))
@property
def jpeg_encoding(self):
return False
def num_examples_per_epoch(self):
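        """Number of sequences in sequence_lengths.txt that are at least hparams.sequence_length frames long."""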
with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
sequence_lengths = sequence_lengths_file.readlines()
sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
def filter(self, serialized_example):
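        """No filtering is applied; every serialized example is kept."""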
return tf.convert_to_tensor(True)
def make_dataset_v2(self, batch_size):
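        """Build a tf.data pipeline that parses, shuffles/repeats, batches and prefetches the TFRecord sequences."""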
def parser(serialized_example):
seqs = OrderedDict()
keys_to_features = {
'width': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'sequence_length': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([],tf.int64),
# 'images/encoded': tf.FixedLenFeature([], tf.string)
'images/encoded': tf.VarLenFeature(tf.float32)
}
# for i in range(20):
# keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
print ("Parse features", parsed_features)
            # 'images/encoded' holds the flattened float pixel values of the whole
            # sequence; densify the sparse tensor and reshape it back to
            # (sequence_length, height, width, channels).
            seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
            print("Image shape {}, {}, {}, {}".format(self.video_shape[0], self.image_shape[0],
                                                      self.image_shape[1], self.image_shape[2]))
            images = tf.reshape(seq, [self.video_shape[0], self.image_shape[0],
                                      self.image_shape[1], self.image_shape[2]], name="reshape_new")
            seqs["images"] = images
return seqs
filenames = self.filenames
print ("FILENAMES",filenames)
#TODO:
#temporal_filenames = self.temporal_filenames
shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
if shuffle:
random.shuffle(filenames)
        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)  # 8 MiB read buffer (size in bytes)
print("files", self.filenames)
print("mode", self.mode)
dataset = dataset.filter(self.filter)
if shuffle:
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))
else:
dataset = dataset.repeat(self.num_epochs)
        # num_parallel_calls = 1 when not shuffling, for reproducibility
        # (e.g. sampled subclips from the test set); otherwise let TF decide
        # (it would normally match the number of usable CPUs).
        num_parallel_calls = None if shuffle else 1
        dataset = dataset.apply(tf.contrib.data.map_and_batch(
            parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
        # prefetch so the input pipeline overlaps with GPU computation
        dataset = dataset.prefetch(batch_size)
return dataset
def make_batch(self, batch_size):
dataset = self.make_dataset_v2(batch_size)
iterator = dataset.make_one_shot_iterator()
return iterator.get_next()
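# Example usage (sketch; the constructor arguments follow the base
# VarLenFeatureVideoDataset class and are assumptions here, not guaranteed by this file):
#
#   dataset = ERA5Dataset_v2(input_dir="/path/to/tfrecords", mode="train", hparams_dict={})
#   batch = dataset.make_batch(batch_size=4)       # dict with key "images"
#   with tf.Session() as sess:
#       images = sess.run(batch["images"])         # (4, sequence_length, height, width, channels)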
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _bytes_list_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def save_tf_record(output_fname, sequences):
print('saving sequences to %s' % output_fname)
with tf.python_io.TFRecordWriter(output_fname) as writer:
for sequence in sequences:
num_frames = len(sequence)
print('num_frames: ',str(num_frames))
print('shape of sequence: ')
print(sequence.shape)
print(sequence[0].shape)
height, width, channels = sequence[0].shape
encoded_sequence = np.array([list(image) for image in sequence])
            print('encoded_sequence shape: {}'.format(encoded_sequence.shape))
features = tf.train.Features(feature={
'sequence_length': _int64_feature(num_frames),
'height': _int64_feature(height),
'width': _int64_feature(width),
'channels': _int64_feature(channels),
'images/encoded': _floats_feature(encoded_sequence.flatten()),
})
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
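# Each serialized example thus contains the scalar int64 features 'sequence_length',
# 'height', 'width' and 'channels' plus 'images/encoded', a flat float list of
# length sequence_length * height * width * channels.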
class norm_data:
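    """Holds per-variable normalization statistics and applies normalization / denormalization."""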
### set known norms and the requested statistics (to be retrieved from statistics.json) here ###
known_norms = {}
known_norms["minmax"] = ["min","max"]
known_norms["znorm"] = ["avg","sigma"]
def __init__(self,varnames):
varnames_uni, _, nvars = get_unique_vars(varnames)
self.varnames = varnames_uni
self.status_ok= False
def check_and_set_norm(self,stat_dict,norm):
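        """Validate the requested normalization and store its statistics as attributes.

        stat_dict is expected to map each variable name to a list whose first element
        is a dict holding the required statistics, e.g.
        {"<varname>": [{"min": ..., "max": ..., "avg": ..., "sigma": ...}], ...}
        (layout inferred from how the values are read below).
        """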
if not norm in self.known_norms.keys():
print("Please select one of the following known normalizations: ")
for norm_avail in self.known_norms.keys():
print(norm_avail)
raise ValueError("Passed normalization '"+norm+"' is unknown.")
if not all(items in stat_dict for items in self.varnames):
print("Keys in stat_dict:")
print(stat_dict.keys())
print("Requested variables:")
print(self.varnames)
raise ValueError("Could not find all requested variables in statistics dictionary.")
for varname in self.varnames:
for stat_name in self.known_norms[norm]:
setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
self.status_ok = True
        # debug output: print the statistics that have just been set
        for varname in self.varnames:
            print(varname)
            for stat_name in self.known_norms[norm]:
                print(getattr(self, varname + stat_name))
def norm_var(self,data,varname,norm):
# some sanity checks
if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.")
if not norm in self.known_norms.keys():
print("Please select one of the following known normalizations: ")
for norm_avail in self.known_norms.keys():
print(norm_avail)
raise ValueError("Passed normalization '"+norm+"' is unknown.")
if norm == "minmax":
return((data[...] - getattr(self,varname+"min"))/(getattr(self,varname+"max") - getattr(self,varname+"min")))
elif norm == "znorm":
return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2)
def denorm_var(self,data,varname,norm):
# some sanity checks
if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.")
if not norm in self.known_norms.keys():
print("Please select one of the following known normalizations: ")
for norm_avail in self.known_norms.keys():
print(norm_avail)
raise ValueError("Passed normalization '"+norm+"' is unknown.")
if norm == "minmax":
return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max"))
elif norm == "znorm":
return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg"))
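# Example usage (sketch; "T2" is a hypothetical variable name, data is a numpy array
# and the statistics are read from a statistics.json file produced by preprocessing):
#
#   norm_cls = norm_data(["T2"])
#   with open("statistics.json") as js_file:
#       norm_cls.check_and_set_norm(json.load(js_file), "minmax")
#   data_norm = norm_cls.norm_var(data, "T2", "minmax")
#   data_back = norm_cls.denorm_var(data_norm, "T2", "minmax")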
def read_frames_and_save_tf_records(output_dir, input_dir, partition_name, vars_in, seq_length=20,
                                    sequences_per_file=128, height=64, width=64, channels=3, **kwargs):  # Bing: original 128
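    """Cut X_<partition_name>.hkl from input_dir into normalized sequences of length seq_length
    and write them as TFRecord files (sequences_per_file sequences each) to
    output_dir/<partition_name>, together with a sequence_lengths.txt file."""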
# ML 2020/04/08:
# Include vars_in for more flexible data handling (normalization and reshaping)
# and optional keyword argument for kind of normalization
if 'norm' in kwargs:
norm = kwargs.get("norm")
else:
norm = "minmax"
print("Make use of default minmax-normalization...")
output_dir = os.path.join(output_dir,partition_name)
os.makedirs(output_dir,exist_ok=True)
norm_cls = norm_data(vars_in)
nvars = len(vars_in)
#vars_uni, indrev = np.unique(vars_in,return_inverse=True)
#if 'norm' in kwargs:
#norm = kwargs.get("norm")
#if (not norm in knwon_norms):
#raise ValueError("Pass valid normalization identifier.")
#print("Known identifiers are: ")
#for norm_name in known_norm:
#print('"'+norm_name+'"')
#else:
#norm = "minmax"
# open statistics file and store the dictionary
with open(os.path.join(input_dir,"statistics.json")) as js_file:
norm_cls.check_and_set_norm(json.load(js_file),norm)
#if (norm == "minmax"):
#varmin, varmax = get_stat_allvars(data,"min",vars_in), get_stat_allvars(data,"max",vars_in)
#print(len(varmin))
#print(varmin)
sequences = []
sequence_iter = 0
sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
X_possible_starts = [i for i in range(len(X_train) - seq_length)]
for X_start in X_possible_starts:
print("Interation", sequence_iter)
X_end = X_start + seq_length
#seq = X_train[X_start:X_end, :, :,:]
seq = X_train[X_start:X_end,:,:]
#print("*****len of seq ***.{}".format(len(seq)))
#seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
if not sequences:
last_start_sequence_iter = sequence_iter
print("reading sequences starting at sequence %d" % sequence_iter)
sequences.append(seq)
sequence_iter += 1
sequence_lengths_file.write("%d\n" % len(seq))
if len(sequences) == sequences_per_file:
            ### Normalization should adopt the selected variables; here duplicated temperature channels were used.
sequences = np.array(sequences)
### normalization
for i in range(nvars):
sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm)
output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
output_fname = os.path.join(output_dir, output_fname)
save_tf_record(output_fname, list(sequences))
sequences = []
sequence_lengths_file.close()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
parser.add_argument("output_dir", type=str)
# ML 2020/04/08 S
# Add vars for ensuring proper normalization and reshaping of sequences
parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
parser.add_argument("-height",type=int,default=64)
parser.add_argument("-width",type = int,default=64)
parser.add_argument("-seq_length",type=int,default=20)
args = parser.parse_args()
current_path = os.getcwd()
#input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
#output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5"
    partition_names = ['train', 'val', 'test']  # note: the 'val' partition has had issues (64, 64, 3)
for partition_name in partition_names:
        read_frames_and_save_tf_records(output_dir=args.output_dir, input_dir=args.input_dir,
                                        vars_in=args.variables, partition_name=partition_name,
                                        seq_length=args.seq_length, height=args.height,
                                        width=args.width, sequences_per_file=2)  # Bing: TODO check sequences_per_file (N_seq)
        # read_frames_and_save_tf_records(output_dir=output_dir, input_dir=input_dir, partition_name=partition_name, N_seq=20)
        # Bing: TODO: first try for N_seq was 10, but it ran into a data-loading issue; try 5.
if __name__ == '__main__':
main()
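# Example invocation (sketch; paths and variable names are placeholders):
#   python era5_dataset_v2.py /path/to/preprocessed_data /path/to/tfrecords \
#       -vars T2 MSL gph500 -height 64 -width 64 -seq_length 20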