era5_dataset_v2.py
    import argparse
    import glob
    import itertools
    import os
    import pickle
    import random
    import re
    import hickle as hkl
    import numpy as np
    import json
    import tensorflow as tf
    from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
    # ML 2020/04/14: hack for getting functions of process_netCDF_v2:
    from os import path
    import sys
    sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
    from DataPreprocess.process_netCDF_v2 import get_unique_vars
    from collections import OrderedDict
    from tensorflow.contrib.training import HParams
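
    # Note: this module relies on TF1.x-only APIs (tf.python_io, tf.contrib,
    # tf.parse_single_example) and will not run under TensorFlow 2.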
    
    class ERA5Dataset_v2(VarLenFeatureVideoDataset):
        def __init__(self, *args, **kwargs):
            super(ERA5Dataset_v2, self).__init__(*args, **kwargs)
            from google.protobuf.json_format import MessageToDict
            # peek at the first serialized example to infer the video and image shapes from its metadata
            example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
            dict_message = MessageToDict(tf.train.Example.FromString(example))
            feature = dict_message['features']['feature']
            self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
            self.image_shape = self.video_shape[1:]
            self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
    
        def get_default_hparams_dict(self):
            default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict()
            hparams = dict(
                context_frames=10,      # Bing: TODO, original is 10
                sequence_length=20,     # Bing: TODO, original is 20
                long_sequence_length=20,
                force_time_shift=True,
                shuffle_on_val=True, 
                use_state=False,
            )
            return dict(itertools.chain(default_hparams.items(), hparams.items()))
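
        # Note: dict(itertools.chain(a.items(), b.items())) merges the two hparam
        # dicts, with the entries defined here overriding the inherited defaults.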
    
    
        @property
        def jpeg_encoding(self):
            return False
    
    
    
        def num_examples_per_epoch(self):
            # count how many stored sequences are at least sequence_length frames long
            with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
                sequence_lengths = sequence_lengths_file.readlines()
            sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
            return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
    
    
        def filter(self, serialized_example):
            # placeholder filter: keep every serialized example
            return tf.convert_to_tensor(True)
    
    
        def make_dataset_v2(self, batch_size):
            def parser(serialized_example):
                seqs = OrderedDict()
                keys_to_features = {
                    'width': tf.FixedLenFeature([], tf.int64),
                    'height': tf.FixedLenFeature([], tf.int64),
                    'sequence_length': tf.FixedLenFeature([], tf.int64),
                    'channels': tf.FixedLenFeature([], tf.int64),
                    # the flattened image sequence is stored as a variable-length float list
                    'images/encoded': tf.VarLenFeature(tf.float32)
                }
                parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
                print("Parsed features:", parsed_features)
                # densify the sparse float list and restore the [T, H, W, C] sequence shape
                seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
                print("Image shape {}, {}, {}, {}".format(self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]))
                images = tf.reshape(seq, [self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]], name="reshape_new")
                seqs["images"] = images
                return seqs
            filenames = self.filenames
            print ("FILENAMES",filenames)
    	#TODO:
    	#temporal_filenames = self.temporal_filenames
            shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
            if shuffle:
                random.shuffle(filenames)
            dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)  # buffer_size: 8 MiB read buffer for the TFRecord files
            print("files", self.filenames)
            print("mode", self.mode)
            dataset = dataset.filter(self.filter)
            if shuffle:
                dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=self.num_epochs))
            else:
                dataset = dataset.repeat(self.num_epochs)
    
            num_parallel_calls = None if shuffle else 1  # single-threaded parsing keeps evaluation reproducible (e.g. sampled subclips from the test set)
            dataset = dataset.apply(tf.contrib.data.map_and_batch(
                parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))  # Bing: parallel data mapping; num_parallel_calls should normally match the number of usable CPUs
            dataset = dataset.prefetch(batch_size)  # Bing: prefetch into a buffer so the GPU does not wait for input data
    
            return dataset
    
    
    
        def make_batch(self, batch_size):
            dataset = self.make_dataset_v2(batch_size)
            iterator = dataset.make_one_shot_iterator()
            return iterator.get_next()
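
    # Minimal usage sketch (hypothetical constructor arguments; the exact signature
    # comes from VarLenFeatureVideoDataset), assuming the TF1.x API used in this file:
    #
    #   dataset = ERA5Dataset_v2(input_dir="/path/to/tfrecords", mode="train")
    #   batch = dataset.make_batch(batch_size=4)
    #   with tf.Session() as sess:
    #       images = sess.run(batch["images"])  # ndarray of shape [4, T, H, W, C]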
    
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    
    def _bytes_list_feature(values):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
    
    def _floats_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    def save_tf_record(output_fname, sequences):
        print('saving sequences to %s' % output_fname)
        with tf.python_io.TFRecordWriter(output_fname) as writer:
            for sequence in sequences:
                num_frames = len(sequence)
                print('num_frames:', num_frames)
                print('shape of sequence:', sequence.shape)
                print('shape of first frame:', sequence[0].shape)
                height, width, channels = sequence[0].shape
                encoded_sequence = np.array([list(image) for image in sequence])
                print('encoded_sequence shape:', encoded_sequence.shape)
    
                features = tf.train.Features(feature={
                    'sequence_length': _int64_feature(num_frames),
                    'height': _int64_feature(height),
                    'width': _int64_feature(width),
                    'channels': _int64_feature(channels),
                    'images/encoded': _floats_feature(encoded_sequence.flatten()),
                })
                example = tf.train.Example(features=features)
                writer.write(example.SerializeToString())
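
    # Each serialized example stores the whole sequence as one flattened float list
    # under 'images/encoded', together with its sequence_length/height/width/channels
    # metadata; parser() in make_dataset_v2 reverses exactly this layout.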
                
    class norm_data:
        
        ### set known norms and the requested statistics (to be retrieved from statistics.json) here ###
        known_norms = {}
        known_norms["minmax"] = ["min","max"]
        known_norms["znorm"]  = ["avg","sigma"]
        
        def __init__(self,varnames):
            varnames_uni, _, nvars = get_unique_vars(varnames)
            
            self.varnames = varnames_uni
            self.status_ok= False
                
        def check_and_set_norm(self,stat_dict,norm):
            
            if norm not in self.known_norms:
                print("Please select one of the following known normalizations: ")
                for norm_avail in self.known_norms.keys():
                    print(norm_avail)
                raise ValueError("Passed normalization '"+norm+"' is unknown.")
           
            if not all(items in stat_dict for items in self.varnames):
                print("Keys in stat_dict:")
                print(stat_dict.keys())
                
                print("Requested variables:")
                print(self.varnames)
                raise ValueError("Could not find all requested variables in statistics dictionary.")   
    
            for varname in self.varnames:
                for stat_name in self.known_norms[norm]:
                    setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
                    
            self.status_ok = True
            # debugging output: print each variable together with its retrieved statistics
            for varname in self.varnames:
                print(varname)
                for stat_name in self.known_norms[norm]:
                    print(getattr(self, varname + stat_name))
                    
        def norm_var(self,data,varname,norm):
            
            # some sanity checks
            if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.")
            
            if norm not in self.known_norms:
                print("Please select one of the following known normalizations: ")
                for norm_avail in self.known_norms.keys():
                    print(norm_avail)
                raise ValueError("Passed normalization '"+norm+"' is unknown.")
            
            if norm == "minmax":
                return((data[...] - getattr(self,varname+"min"))/(getattr(self,varname+"max") - getattr(self,varname+"min")))
            elif norm == "znorm":
                return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2)
            
        def denorm_var(self,data,varname,norm):
            
            # some sanity checks
            if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.")        
            
            if norm not in self.known_norms:
                print("Please select one of the following known normalizations: ")
                for norm_avail in self.known_norms.keys():
                    print(norm_avail)
                raise ValueError("Passed normalization '"+norm+"' is unknown.")
            
            if norm == "minmax":
                return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max"))
            elif norm == "znorm":
                return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg"))
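
    # Round-trip sketch of the two schemes (per-variable statistics from statistics.json):
    #   minmax:  x_n = (x - min) / (max - min)    <->   x = x_n * (max - min) + min
    #   znorm:   x_n = (x - avg) / sigma          <->   x = x_n * sigma + avg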
            
    
    def read_frames_and_save_tf_records(output_dir, input_dir, partition_name, vars_in, seq_length=20, sequences_per_file=128, height=64, width=64, channels=3, **kwargs):  # Bing: original sequences_per_file is 128
        # ML 2020/04/08:
        # Include vars_in for more flexible data handling (normalization and reshaping)
        # and optional keyword argument for kind of normalization
        
        if 'norm' in kwargs:
            norm = kwargs.get("norm")
        else:
            norm = "minmax"
            print("Make use of default minmax-normalization...")
    
        output_dir = os.path.join(output_dir,partition_name)
        os.makedirs(output_dir,exist_ok=True)
        
        norm_cls  = norm_data(vars_in)
        nvars     = len(vars_in)
        
        # open the statistics file and hand its dictionary to the normalization object
        with open(os.path.join(input_dir, "statistics.json")) as js_file:
            norm_cls.check_and_set_norm(json.load(js_file), norm)
        
        sequences = []
        sequence_iter = 0
        sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
        X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
        X_possible_starts = [i for i in range(len(X_train) - seq_length)]
        for X_start in X_possible_starts:
            print("Interation", sequence_iter)
            X_end = X_start + seq_length
            #seq = X_train[X_start:X_end, :, :,:]
            seq = X_train[X_start:X_end,:,:]
            #print("*****len of seq ***.{}".format(len(seq)))
            #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
            seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
            if not sequences:
                last_start_sequence_iter = sequence_iter
                print("reading sequences starting at sequence %d" % sequence_iter)
            sequences.append(seq)
            sequence_iter += 1
            sequence_lengths_file.write("%d\n" % len(seq))
    
            if len(sequences) == sequences_per_file:
                ### Normalization should adopt the selected variables; here we use duplicated temperature channels
                sequences = np.array(sequences)
                ### normalize each variable separately
                for i in range(nvars):
                    sequences[:, :, :, :, i] = norm_cls.norm_var(sequences[:, :, :, :, i], vars_in[i], norm)
    
                output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
                output_fname = os.path.join(output_dir, output_fname)
                save_tf_record(output_fname, list(sequences))
                sequences = []
        sequence_lengths_file.close()
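
    # Note: the sequence_lengths.txt written here is the same file that
    # ERA5Dataset_v2.num_examples_per_epoch() reads back to count usable sequences.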
    
    
    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
        parser.add_argument("output_dir", type=str)
        # ML 2020/04/08 S
        # Add vars for ensuring proper normalization and reshaping of sequences
        parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
        parser.add_argument("-height",type=int,default=64)
        parser.add_argument("-width",type = int,default=64)
        parser.add_argument("-seq_length",type=int,default=20)
        args = parser.parse_args()
        partition_names = ['train', 'val', 'test']  # Bing: the (64,64,3) val partition has an issue
      
        for partition_name in partition_names:
            read_frames_and_save_tf_records(output_dir=args.output_dir, input_dir=args.input_dir, vars_in=args.variables, partition_name=partition_name, seq_length=args.seq_length, height=args.height, width=args.width, sequences_per_file=2)  # Bing: TODO, check the number of sequences per file (N_seq)
    
    if __name__ == '__main__':
        main()
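
    # Example invocation (hypothetical paths and variable names; the -vars entries
    # must match the keys in statistics.json):
    #   python era5_dataset_v2.py /path/to/preprocessed /path/to/tfrecords -vars T2 MSL gph500 -height 64 -width 64 -seq_length 20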