import argparse
import glob
import itertools
import os
import pickle
import random
import re
import hickle as hkl
import numpy as np
import json
import tensorflow as tf
from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
# ML 2020/04/14: hack for getting functions of process_netCDF_v2:
from os import path
import sys
sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/'))
from DataPreprocess.process_netCDF_v2 import get_unique_vars
#from base_dataset import VarLenFeatureVideoDataset
from collections import OrderedDict
from tensorflow.contrib.training import HParams

class ERA5Dataset_v2(VarLenFeatureVideoDataset):
    def __init__(self, *args, **kwargs):
        super(ERA5Dataset_v2, self).__init__(*args, **kwargs)
        from google.protobuf.json_format import MessageToDict
        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
        dict_message = MessageToDict(tf.train.Example.FromString(example))
        feature = dict_message['features']['feature']
        self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
        self.image_shape = self.video_shape[1:]
        self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape

    def get_default_hparams_dict(self):
        default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict()
        hparams = dict(
            context_frames=10,#Bing: Todo oriignal is 10
            sequence_length=20,#bing: TODO original is 20,
            long_sequence_length=20,
            force_time_shift=True,
            shuffle_on_val=True, 
            use_state=False,
        )
        return dict(itertools.chain(default_hparams.items(), hparams.items()))


    @property
    def jpeg_encoding(self):
        return False



    def num_examples_per_epoch(self):
        with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
            sequence_lengths = sequence_lengths_file.readlines()
        sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)


    def filter(self, serialized_example):
        return tf.convert_to_tensor(True)


    def make_dataset_v2(self, batch_size):
        def parser(serialized_example):
            seqs = OrderedDict()
            keys_to_features = {
                'width': tf.FixedLenFeature([], tf.int64),
                'height': tf.FixedLenFeature([], tf.int64),
                'sequence_length': tf.FixedLenFeature([], tf.int64),
                'channels': tf.FixedLenFeature([],tf.int64),
                # 'images/encoded':  tf.FixedLenFeature([], tf.string)
                'images/encoded': tf.VarLenFeature(tf.float32)
            }
            
            # for i in range(20):
            #     keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
            parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
            print ("Parse features", parsed_features)
            seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
           # width = tf.sparse_tensor_to_dense(parsed_features["width"])
           # height = tf.sparse_tensor_to_dense(parsed_features["height"])
           # channels  = tf.sparse_tensor_to_dense(parsed_features["channels"])
           # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
            images = []
            # for i in range(20):
            #    images.append(parsed_features["images/encoded"].values[i])
            # images = parsed_features["images/encoded"]
            # images = tf.map_fn(lambda i: tf.image.decode_jpeg(parsed_features["images/encoded"].values[i]),offsets)
            # seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"], '')
            # Parse the string into an array of pixels corresponding to the image
            # images = tf.decode_raw(parsed_features["images/encoded"],tf.int32)

            # images = seq
            print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
            images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
            seqs["images"] = images
            return seqs
        filenames = self.filenames
        print ("FILENAMES",filenames)
	#TODO:
	#temporal_filenames = self.temporal_filenames
        shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
        if shuffle:
            random.shuffle(filenames)
        dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024)  # todo: what is buffer_size
        print("files", self.filenames)
        print("mode", self.mode)
        dataset = dataset.filter(self.filter)
        if shuffle:
            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))
        else:
            dataset = dataset.repeat(self.num_epochs)

        num_parallel_calls = None if shuffle else 1
        dataset = dataset.apply(tf.contrib.data.map_and_batch(
            parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
        #dataset = dataset.map(parser)
        # num_parallel_calls = None if shuffle else 1  # for reproducibility (e.g. sampled subclips from the test set)
        # dataset = dataset.apply(tf.contrib.data.map_and_batch(
        #    _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) #  Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs
        dataset = dataset.prefetch(batch_size)  # Bing: Take the data to buffer inorder to save the waiting time for GPU

        return dataset



    def make_batch(self, batch_size):
        dataset = self.make_dataset_v2(batch_size)
        iterator = dataset.make_one_shot_iterator()
        return iterator.get_next()

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _bytes_list_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

def _floats_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def save_tf_record(output_fname, sequences):
    print('saving sequences to %s' % output_fname)
    with tf.python_io.TFRecordWriter(output_fname) as writer:
        for sequence in sequences:
            num_frames = len(sequence)
            print('num_frames: ',str(num_frames))
            print('shape of sequence: ')
            print(sequence.shape)
            print(sequence[0].shape)
            height, width, channels = sequence[0].shape
            encoded_sequence = np.array([list(image) for image in sequence])
            print('encoded_sequence: '+encoded_sequence)

            features = tf.train.Features(feature={
                'sequence_length': _int64_feature(num_frames),
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'channels': _int64_feature(channels),
                'images/encoded': _floats_feature(encoded_sequence.flatten()),
            })
            example = tf.train.Example(features=features)
            writer.write(example.SerializeToString())
            
class norm_data:
    
    ### set known norms and the requested statistics (to be retrieved from statistics.json) here ###
    known_norms = {}
    known_norms["minmax"] = ["min","max"]
    known_norms["znorm"]  = ["avg","sigma"]
    
    def __init__(self,varnames):
        varnames_uni, _, nvars = get_unique_vars(varnames)
        
        self.varnames = varnames_uni
        self.status_ok= False
            
    def check_and_set_norm(self,stat_dict,norm):
        
        if not norm in self.known_norms.keys():
            print("Please select one of the following known normalizations: ")
            for norm_avail in self.known_norms.keys():
                print(norm_avail)
            raise ValueError("Passed normalization '"+norm+"' is unknown.")
       
        if not all(items in stat_dict for items in self.varnames):
            print("Keys in stat_dict:")
            print(stat_dict.keys())
            
            print("Requested variables:")
            print(self.varnames)
            raise ValueError("Could not find all requested variables in statistics dictionary.")   

        for varname in self.varnames:
            for stat_name in self.known_norms[norm]:
                setattr(self,varname+stat_name,stat_dict[varname][0][stat_name])
                
        self.status_ok = True
        for i in range(len(self.varnames)):
            print(self.varnames[i])
            print(getattr(self,self.varnames[i]+"min"))
                
    def norm_var(self,data,varname,norm):
        
        # some sanity checks
        if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.")
        
        if not norm in self.known_norms.keys():
            print("Please select one of the following known normalizations: ")
            for norm_avail in self.known_norms.keys():
                print(norm_avail)
            raise ValueError("Passed normalization '"+norm+"' is unknown.")
        
        if norm == "minmax":
            return((data[...] - getattr(self,varname+"min"))/(getattr(self,varname+"max") - getattr(self,varname+"min")))
        elif norm == "znorm":
            return((data[...] - getattr(self,varname+"avg"))/getattr(self,varname+"sigma")**2)
        
    def denorm_var(self,data,varname,norm):
        
        # some sanity checks
        if not self.status_ok: raise ValueError("norm_data-object needs to be initialized and checked first.")        
        
        if not norm in self.known_norms.keys():
            print("Please select one of the following known normalizations: ")
            for norm_avail in self.known_norms.keys():
                print(norm_avail)
            raise ValueError("Passed normalization '"+norm+"' is unknown.")
        
        if norm == "minmax":
            return(data[...] * (getattr(self,varname+"max") - getattr(self,varname+"min")) + getattr(self,varname+"max"))
        elif norm == "znorm":
            return(data[...] * getattr(self,varname+"sigma")**2 + getattr(self,varname+"avg"))
        

def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,seq_length=20,sequences_per_file=128,height=64,width=64,channels=3,**kwargs):#Bing: original 128
    # ML 2020/04/08:
    # Include vars_in for more flexible data handling (normalization and reshaping)
    # and optional keyword argument for kind of normalization
    
    if 'norm' in kwargs:
        norm = kwargs.get("norm")
    else:
        norm = "minmax"
        print("Make use of default minmax-normalization...")

    output_dir = os.path.join(output_dir,partition_name)
    os.makedirs(output_dir,exist_ok=True)
    
    norm_cls  = norm_data(vars_in)
    nvars     = len(vars_in)
    #vars_uni, indrev = np.unique(vars_in,return_inverse=True)
    #if 'norm' in kwargs:
        #norm = kwargs.get("norm")
        #if (not norm in knwon_norms): 
            #raise ValueError("Pass valid normalization identifier.")
            #print("Known identifiers are: ")
            #for norm_name in known_norm:
                #print('"'+norm_name+'"')
    #else:
        #norm = "minmax"
    
    # open statistics file and store the dictionary
    with open(os.path.join(input_dir,"statistics.json")) as js_file:
        norm_cls.check_and_set_norm(json.load(js_file),norm)        
    
        #if (norm == "minmax"):
            #varmin, varmax = get_stat_allvars(data,"min",vars_in), get_stat_allvars(data,"max",vars_in)

    #print(len(varmin))
    #print(varmin)
    
    sequences = []
    sequence_iter = 0
    sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
    X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
    X_possible_starts = [i for i in range(len(X_train) - seq_length)]
    for X_start in X_possible_starts:
        print("Interation", sequence_iter)
        X_end = X_start + seq_length
        #seq = X_train[X_start:X_end, :, :,:]
        seq = X_train[X_start:X_end,:,:]
        #print("*****len of seq ***.{}".format(len(seq)))
        #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
        seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
        if not sequences:
            last_start_sequence_iter = sequence_iter
            print("reading sequences starting at sequence %d" % sequence_iter)
        sequences.append(seq)
        sequence_iter += 1
        sequence_lengths_file.write("%d\n" % len(seq))

        if len(sequences) == sequences_per_file:
            ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables
            sequences = np.array(sequences)
            ### normalization
            for i in range(nvars):    
                sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm)

            output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1)
            output_fname = os.path.join(output_dir, output_fname)
            save_tf_record(output_fname, list(sequences))
            sequences = []
    sequence_lengths_file.close()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
    parser.add_argument("output_dir", type=str)
    # ML 2020/04/08 S
    # Add vars for ensuring proper normalization and reshaping of sequences
    parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
    parser.add_argument("-height",type=int,default=64)
    parser.add_argument("-width",type = int,default=64)
    parser.add_argument("-seq_length",type=int,default=20)
    args = parser.parse_args()
    current_path = os.getcwd()
    #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
    #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5"
    partition_names = ['train','val',  'test'] #64,64,3 val has issue#
  
    for partition_name in partition_names:
        read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, seq_length=args.seq_length,height=args.height,width=args.width,sequences_per_file=2) #Bing: Todo need check the N_seq
        #ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5

if __name__ == '__main__':
    main()