"""
Class and functions required for preprocessing ERA5 data (preprocessing substep 2)
"""
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong"
__date__ = "2020_12_29"


# import modules
import os
import glob
import pickle
import numpy as np
import pandas as pd
import json
import tensorflow as tf
from normalization import Norm_data
from metadata import MetaData
import datetime
from model_modules.video_prediction.datasets import ERA5Dataset


class ERA5Pkl2Tfrecords(ERA5Dataset):
    def __init__(self, input_dir=None, dest_dir=None,  sequence_length=20, sequences_per_file=128, norm="minmax"):
        """
        This class is used for converting pkl files to tfrecords
        args:
            input_dir            : str, the path to the PreprocessData directory which is parent directory of "Pickle"
                                   and "tfrecords" files directiory.
            sequence_length      : int, default is 20, the sequen length per sample
            sequences_per_file   : int, how many sequences/samples per tfrecord to be saved
            norm                 : str, normalization methods from Norm_data class ("minmax" or "znorm";
                                   default: "minmax")
        """
        self.input_dir = input_dir
        self.output_dir = dest_dir
        # if the output_dir does not exist, then create it
        os.makedirs(self.output_dir, exist_ok=True)
        # get metadata,includes the var_in, image height, width etc.
        self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json")
        self.get_metadata(MetaData(json_file=self.metadata_fl))
        # Get the data split informaiton
        self.sequence_length = sequence_length
        if norm == "minmax" or norm == "znorm":
            self.norm = norm
        else:
            raise ValueError("norm should be either 'minmax' or 'znorm'")
        self.sequences_per_file = sequences_per_file
        self.write_sequence_file()

    def get_years_months(self):
        """
        Get the months in the datasplit_config
        Return : 
                two elements: each contains 1-dim array with the months set from data_split_config json file
        """
        self.months = []
        self.years_months = []
        # search for pickle names with pattern 'X_{}.pkl'for months
        self.years = [name for name in os.listdir(self.input_dir) if os.path.isdir(os.path.join(self.input_dir, name))]
        # search for folder names from pickle folder to get years
        patt = "X_*.pkl"         
        for year in self.years:
            months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt))
            months_list = [int(m[-6:-4]) for m in months_pkl_list]
            self.months.extend(months_list)
            self.years_months.append(months_list)
        return self.years, list(set(self.months)), self.years_months

    def get_stats_file(self):
        """
        Get the corresponding statistics file
        """
        method = ERA5Pkl2Tfrecords.get_stats_file.__name__

        stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json")
        print("Opening json-file: {0}".format(stats_file))
        if os.path.isfile(stats_file):
            with open(stats_file) as js_file:
                self.stats = json.load(js_file)
        else:
            raise FileNotFoundError("%{0}: Could not find statistic file '{1}'".format(method, stats_file))

    def get_metadata(self, md_instance):
        """
        This function gets the meta data that has been generated in data_process_step1. Here, we aim to extract
        the height and width information from it
        vars_in   : list(str), must be consistent with the list from DataPreprocessing_step1
        height    : int, the height of the image
        width     : int, the width of the image
        """
        method = ERA5Pkl2Tfrecords.get_metadata.__name__
        
        if not isinstance(md_instance, MetaData):
            raise ValueError("%{0}: md_instance-argument must be a MetaData class instance".format(method))

        if not hasattr(self, "metadata_fl"):
            raise ValueError("%{0}: MetaData class instance passed, but attribute metadata_fl is still missing.".format(method))

        try:
            self.height, self.width = md_instance.ny, md_instance.nx
            self.vars_in = md_instance.variables
        except:
            raise IOError("%{0}: Could not retrieve all required information from metadata-file '{0}'"
                          .format(method, self.metadata_fl))

    @staticmethod
    def save_tf_record(output_fname, sequences, t_start_points):
        """
        Save the sequences, and the corresponding timestamp start point to tfrecords
        args:
            output_frames    : str, the file names of the output
            sequences        : list or array, the sequences want to be saved to tfrecords,
                               [sequences,seq_len,height,width,channels]
            t_start_points   : datetime type in the list, the first timestamp for each sequence
                               [seq_len,height,width, channel], the len of t_start_points is the same as sequences
        """
        method = ERA5Pkl2Tfrecords.save_tf_record.__name__

        sequences = np.array(sequences)
        # sanity checks
        assert sequences.shape[0] == len(t_start_points), "%{0}: Lengths of sequence differs from length of t_start_points.".format(method)
        assert isinstance(t_start_points[0], datetime.datetime), "%{0}: Elements of t_start_points must be datetime-objects.".format(method)

        with tf.python_io.TFRecordWriter(output_fname) as writer:
            for i in range(len(sequences)):
                sequence = sequences[i]

                t_start = t_start_points[i].strftime("%Y%m%d%H")
                num_frames = len(sequence)
                height, width, channels = sequence[0].shape
                encoded_sequence = np.array([list(image) for image in sequence])
                features = tf.train.Features(feature={
                    'sequence_length': _int64_feature(num_frames),
                    'height': _int64_feature(height),
                    'width': _int64_feature(width),
                    'channels': _int64_feature(channels),
                    't_start': _int64_feature(int(t_start)),
                    'images/encoded': _floats_feature(encoded_sequence.flatten()),
                })
                example = tf.train.Example(features=features)
                writer.write(example.SerializeToString())

    def init_norm_class(self):
        """
        Get normalization data class 
        """
        method = ERA5Pkl2Tfrecords.init_norm_class.__name__

        print("%{0}: Make use of default minmax-normalization.".format(method))
        # init normalization-instance
        self.norm_cls = Norm_data(self.vars_in)
        self.nvars = len(self.vars_in)
        # get statistics file
        self.get_stats_file()
        # open statistics file and feed it to norm-instance
        self.norm_cls.check_and_set_norm(self.stats, self.norm)

    def normalize_vars_per_seq(self, sequences):
        """
        Normalize all the variables for the sequences
        args:
            sequences: list or array, is the sequences need to be saved to tfrecorcd.
                       The shape should be [sequences_per_file,seq_length,height,width,nvars]
        Return:
            the normalized sequences
        """
        method = ERA5Pkl2Tfrecords.normalize_vars_per_seq.__name__

        assert len(np.array(sequences).shape) == 5, "%{0}: Length of sequence array must be 5.".format(method)
        # normalization should adpot the selected variables, here we used duplicated channel temperature variables
        sequences = np.array(sequences)
        # normalization
        for i in range(self.nvars):
            sequences[..., i] = self.norm_cls.norm_var(sequences[..., i], self.vars_in[i], self.norm)
        return sequences

    def read_pkl_and_save_tfrecords(self, year, month):
        """
        Read pickle files based on month, to process and save to tfrecords,
        args:
            year    : int, the target year to save to tfrecord
            month   : int, the target month to save to tfrecord 
        """
        method = ERA5Pkl2Tfrecords.read_pkl_and_save_tfrecords.__name__

        # Define the input_file based on the year and month
        self.input_file_year = os.path.join(self.input_dir, str(year))
        input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month))
        temp_input_file = os.path.join(self.input_file_year, 'T_{:02d}.pkl'.format(month))

        self.init_norm_class()
        sequences = []
        t_start_points = []
        sequence_iter = 0

        try:
            with open(input_file, "rb") as data_file:
                X_train = pickle.load(data_file)
        except:
            raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, input_file))

        try:
            with open(temp_input_file, "rb") as temp_file:
                T_train = pickle.load(temp_file)
        except:
            raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, temp_input_file))

        # check to make sure that X_train and T_train have the same length
        assert (len(X_train) == len(T_train))

        X_possible_starts = [i for i in range(len(X_train) - self.sequence_length)]
        for X_start in X_possible_starts:
            X_end = X_start + self.sequence_length
            seq = X_train[X_start:X_end, ...]
            # recording the start point of the timestamps (already datetime-objects)
            t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0])
            seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars)))
            if not sequences:
                last_start_sequence_iter = sequence_iter
            sequences.append(seq)
            t_start_points.append(t_start)
            sequence_iter += 1

            if len(sequences) == self.sequences_per_file:
                # normalize variables in the sequences
                sequences = ERA5Pkl2Tfrecords.normalize_vars_per_seq(self, sequences)
                output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year, month, last_start_sequence_iter,
                                                                              sequence_iter - 1)
                output_fname = os.path.join(self.output_dir, output_fname)
                # write to tfrecord
                ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points)
                t_start_points = []
                sequences = []
        print("%{0}: Finished processing of input file '{1}'".format(method, input_file))

#         except FileNotFoundError as fnf_error:
#             print(fnf_error)

    @staticmethod
    def write_seq_to_tfrecord(output_fname, sequences, t_start_points):
        """
        Function to check if the sequences has been processed.
        If yes, the sequences are skipped, otherwise the sequences are saved to the output file
        """
        method = ERA5Pkl2Tfrecords.write_seq_to_tfrecord.__name__

        if os.path.isfile(output_fname):
            print("%{0}: TFrecord-file {1} already exists. It is therefore skipped.".format(method, output_fname))
        else:
            ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points)

    def write_sequence_file(self):
        """
        Generate a txt file, with the numbers of sequences for each tfrecords file.
        This is mainly used for calculting the number of samples for each epoch during training epoch
        """

        with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file:
            seq_file.write("%d\n" % self.sequences_per_file)


    @staticmethod
    def ensure_datetime(date):
        """
        Wrapper to return a datetime-object
        """
        method = ERA5Pkl2Tfrecords.ensure_datetime.__name__

        fmt = "%Y%m%d %H:%M"
        if isinstance(date, datetime.datetime):
            date_new = date
        else:
            try:
                date_new=pd.to_datetime(date)
                date_new=date_new.to_pydatetime()
            except Exception as err:
                print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date)))
                raise err

        return date_new

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _bytes_list_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def _floats_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))