Skip to content
Snippets Groups Projects
Select Git revision
  • 02055510025f9bc5e40a56d0fbe56372f1e87cc9
  • master default
  • bing_issues#190_tf2
  • bing_tf2_convert
  • bing_issue#189_train_modular
  • simon_#172_integrate_weatherbench
  • develop
  • bing_issue#188_restructure_ambs
  • yan_issue#100_extract_prcp_data
  • bing_issue#170_data_preprocess_training_tf1
  • Gong2022_temperature_forecasts
  • bing_issue#186_clean_GMD1_tag
  • yan_issue#179_integrate_GZAWS_data_onfly
  • bing_issue#178_runscript_bug_postprocess
  • michael_issue#187_bugfix_setup_runscript_template
  • bing_issue#180_bugs_postprpocess_meta_postprocess
  • yan_issue#177_repo_for_CLGAN_gmd
  • bing_issue#176_integrate_weather_bench
  • michael_issue#181_eval_era5_forecasts
  • michael_issue#182_eval_subdomain
  • michael_issue#119_warmup_Horovod
  • bing_issue#160_test_zam347
  • ambs_v1
  • ambs_gmd_nowcasting_v1.0
  • GMD1
  • modular_booster_20210203
  • new_structure_20201004_v1.0
  • old_structure_20200930
28 results

DataPreprocess_to_tf.sh

Blame
  • generate_transfer_learning_finetune.py 33.17 KiB
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import errno
    import os
    import math
    import tensorflow as tf
    import random
    import json
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    from video_prediction import datasets, models
    from skimage.metrics import structural_similarity as ssim
    import pickle
    
    # Module-level side effect: load the geographic grid coordinates from
    # geo_info.json (expected in the current working directory) at import time.
    # lat/lon are rounded to 2 decimals and used later as axis tick labels
    # for the ground-truth / prediction plots.
    # NOTE(review): raises FileNotFoundError on import if geo_info.json is
    # absent; assumes geo["lat"] and geo["lon"] are lists of floats — confirm
    # against the preprocessing step that writes this file.
    with open("geo_info.json","r") as json_file:
        geo = json.load(json_file)
        lat = [round(i,2) for i in geo["lat"]]
        lon = [round(i,2) for i in geo["lon"]]
    
    
    def psnr(img1, img2, pixel_max = 1):
        """Compute the peak signal-to-noise ratio (PSNR, in dB) between two images.

        Args:
            img1: numpy array (reference image).
            img2: numpy array of the same shape (compared image).
            pixel_max: maximum possible pixel value of the images
                (default 1, matching data normalized to [0, 1]).

        Returns:
            PSNR in dB as a float. Returns 100 as a finite sentinel when the
            images are identical (the true PSNR would be infinite).
        """
        mse = np.mean((img1 - img2) ** 2)
        if mse == 0:
            # Identical images -> infinite PSNR; cap at 100 dB so callers
            # can still aggregate/plot the value.
            return 100
        return 20 * math.log10(pixel_max / math.sqrt(mse))
    
    
    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("--input_dir", type = str, required = True,
                            help = "either a directory containing subdirectories "
                                   "train, val, test, etc, or a directory containing "
                                   "the tfrecords")
        parser.add_argument("--results_dir", type = str, default = 'results',
                            help = "ignored if output_gif_dir is specified")
        parser.add_argument("--results_gif_dir", type = str,
                            help = "default is results_dir. ignored if output_gif_dir is specified")
        parser.add_argument("--results_png_dir", type = str,
                            help = "default is results_dir. ignored if output_png_dir is specified")
        parser.add_argument("--output_gif_dir", help = "output directory where samples are saved as gifs. default is "
                                                       "results_gif_dir/model_fname")
        parser.add_argument("--output_png_dir", help = "output directory where samples are saved as pngs. default is "
                                                       "results_png_dir/model_fname")
        parser.add_argument("--checkpoint",
                            help = "directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
    
        parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val',
                            help = 'mode for dataset, val or test.')
    
        parser.add_argument("--dataset", type = str, help = "dataset class name")
        parser.add_argument("--dataset_hparams", type = str,
                            help = "a string of comma separated list of dataset hyperparameters")
        parser.add_argument("--model", type = str, help = "model class name")
        parser.add_argument("--model_hparams", type = str,
                            help = "a string of comma separated list of model hyperparameters")
    
        parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch")
        parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)")
        parser.add_argument("--num_epochs", type = int, default = 1)
    
        parser.add_argument("--num_stochastic_samples", type = int, default = 1)
        parser.add_argument("--gif_length", type = int, help = "default is sequence_length")
        parser.add_argument("--fps", type = int, default = 4)
    
        parser.add_argument("--gpu_mem_frac", type = float, default = 0, help = "fraction of gpu memory to use")
        parser.add_argument("--seed", type = int, default = 7)
    
        args = parser.parse_args()
    
        if args.seed is not None:
            tf.set_random_seed(args.seed)
            np.random.seed(args.seed)
            random.seed(args.seed)
    
        args.results_gif_dir = args.results_gif_dir or args.results_dir
        args.results_png_dir = args.results_png_dir or args.results_dir
        dataset_hparams_dict = {}
        model_hparams_dict = {}
        if args.checkpoint:
            checkpoint_dir = os.path.normpath(args.checkpoint)
            if not os.path.isdir(args.checkpoint):
                checkpoint_dir, _ = os.path.split(checkpoint_dir)
            if not os.path.exists(checkpoint_dir):
                raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
            with open(os.path.join(checkpoint_dir, "options.json")) as f:
                print("loading options from checkpoint %s" % args.checkpoint)
                options = json.loads(f.read())
                args.dataset = args.dataset or options['dataset']
                args.model = args.model or options['model']
            try:
                with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                    dataset_hparams_dict = json.loads(f.read())
            except FileNotFoundError:
                print("dataset_hparams.json was not loaded because it does not exist")
            try:
                with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                    model_hparams_dict = json.loads(f.read())
            except FileNotFoundError:
                print("model_hparams.json was not loaded because it does not exist")
            args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir,
                                                                      os.path.split(checkpoint_dir)[1])
            args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir,
                                                                      os.path.split(checkpoint_dir)[1])
        else:
            if not args.dataset:
                raise ValueError('dataset is required when checkpoint is not specified')
            if not args.model:
                raise ValueError('model is required when checkpoint is not specified')
            args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model)
            args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model)
    
        print('----------------------------------- Options ------------------------------------')
        for k, v in args._get_kwargs():
            print(k, "=", v)
        print('------------------------------------- End --------------------------------------')
    
        VideoDataset = datasets.get_dataset_class(args.dataset)
        dataset = VideoDataset(
            args.input_dir,
            mode = args.mode,
            num_epochs = args.num_epochs,
            seed = args.seed,
            hparams_dict = dataset_hparams_dict,
            hparams = args.dataset_hparams)
    
        VideoPredictionModel = models.get_model_class(args.model)
        hparams_dict = dict(model_hparams_dict)
        hparams_dict.update({
            'context_frames': dataset.hparams.context_frames,
            'sequence_length': dataset.hparams.sequence_length,
            'repeat': dataset.hparams.time_shift,
        })
        model = VideoPredictionModel(
            mode = args.mode,
            hparams_dict = hparams_dict,
            hparams = args.model_hparams)
    
        sequence_length = model.hparams.sequence_length
        context_frames = model.hparams.context_frames
        future_length = sequence_length - context_frames
    
        if args.num_samples:
            if args.num_samples > dataset.num_examples_per_epoch():
                raise ValueError('num_samples cannot be larger than the dataset')
            num_examples_per_epoch = args.num_samples
        else:
            num_examples_per_epoch = dataset.num_examples_per_epoch()
        if num_examples_per_epoch % args.batch_size != 0:
            raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
    
        inputs = dataset.make_batch(args.batch_size)
        input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
        with tf.variable_scope(''):
            model.build_graph(input_phs)
    
        for output_dir in (args.output_gif_dir, args.output_png_dir):
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            with open(os.path.join(output_dir, "options.json"), "w") as f:
                f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
            with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
                f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
            with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
                f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
    
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac)
        config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True)
        sess = tf.Session(config = config)
        sess.graph.as_default()
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        model.restore(sess, args.checkpoint)
    
        sample_ind = 0
        gen_images_all = []
        input_images_all = []
    
        while True:
            print("Sample id", sample_ind)
            gen_images_stochastic = []
            if args.num_samples and sample_ind >= args.num_samples:
                break
            try:
                input_results = sess.run(inputs)
                input_images = input_results["images"]
                input_images_all.extend(input_images)
                with open(os.path.join(args.output_png_dir, "input_images_all"), "wb") as input_files:
                    pickle.dump(list(input_images_all), input_files)
    
            except tf.errors.OutOfRangeError:
                break
    
            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
            for stochastic_sample_ind in range(args.num_stochastic_samples):
                gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)
                gen_images_stochastic.append(gen_images)
                print("Stochastic_sample,", stochastic_sample_ind)
                for i in range(args.batch_size):
                    print("batch", i)
                    #colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
                    cmap_name = 'my_list'
                    if sample_ind < 20 and i == 1:
                        name = 'Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
                            sample_ind) + " + Sample_" + str(i)
                        gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
                        #gen_images_ =  gen_images[i, :]
                        input_images_ = input_images[i, :]
                        input_gen_diff = (input_images_[:, :, :,0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922) - (gen_images_[:, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922)
    
                        gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in
                                        range(sequence_length)]  # return the list with 10 (sequence) mse
    
                        fig = plt.figure(figsize=(18,6))
                        gs = gridspec.GridSpec(1, 10)
                        gs.update(wspace = 0., hspace = 0.)
                        ts = [0,5,9,10,12,14,16,18,19]
                        xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
                        ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
    
                        for t in range(len(ts)):
                            #if t==0 : ax1=plt.subplot(gs[t])
                            ax1 = plt.subplot(gs[t])
                            input_image = input_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
                            plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300)
                            ax1.title.set_text("t = " + str(ts[t]+1))
                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
    
                            if t == 0:
                                plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels)
                                plt.ylabel("Ground Truth", fontsize=10)
                        plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg"))
                        plt.clf()
    
                        fig = plt.figure(figsize=(12,6))
                        gs = gridspec.GridSpec(1, 10)
                        gs.update(wspace = 0., hspace = 0.)
                        ts = [10,12,14,16,18,19]
                        for t in range(len(ts)):
                            #if t==0 : ax1=plt.subplot(gs[t])
                            ax1 = plt.subplot(gs[t])
                            gen_image = gen_images_[ts[t], :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922
                            plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300)
                            ax1.title.set_text("t = " + str(ts[t]+1))
                            plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
    
                        plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg"))
                        plt.clf()
    
                        # fig = plt.figure()
                        # gs = gridspec.GridSpec(4,6)
                        # gs.update(wspace = 0.7,hspace=0.8)
                        # ax1 = plt.subplot(gs[0:2,0:3])
                        # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1)
                        # ax3 = plt.subplot(gs[2:4,0:3])
                        # ax4 = plt.subplot(gs[2:4,3:])
                        # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))]
                        # ylabels = [round(i,2) for i  in list(np.linspace(np.max(lat),np.min(lat),5))]
                        # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels)
                        # ax1.title.set_text("(a) Ground Truth")
                        # ax2.title.set_text("(b) SAVP")
                        # ax3.title.set_text("(c) Diff.")
                        # ax4.title.set_text("(d) MSE")
                        #
                        # ax1.xaxis.set_tick_params(labelsize=7)
                        # ax1.yaxis.set_tick_params(labelsize = 7)
                        # ax2.xaxis.set_tick_params(labelsize=7)
                        # ax2.yaxis.set_tick_params(labelsize = 7)
                        # ax3.xaxis.set_tick_params(labelsize=7)
                        # ax3.yaxis.set_tick_params(labelsize = 7)
                        #
                        # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2]))
                        # print("inti images shape", init_images.shape)
                        # xdata, ydata = [], []
                        # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
                        # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1)
                        # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
                        # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300)
                        # #x = np.linspace(0, 64, 64)
                        # #y = np.linspace(0, 64, 64)
                        # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
                        # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images))
                        # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7)
                        # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7)
                        #
                        # cm = LinearSegmentedColormap.from_list(
                        #     cmap_name, "bwr", N = 5)
                        #
                        # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r',
                        # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm)  # cmap = 'PuBu_r',
                        # plot4, = ax4.plot([], [], color = "r")
                        # ax4.set_xlim(0, future_length-1)
                        # ax4.set_ylim(0, 20)
                        # #ax4.set_ylim(0, 0.5)
                        # ax4.set_xlabel("Frames", fontsize=10)
                        # #ax4.set_ylabel("MSE", fontsize=10)
                        # ax4.xaxis.set_tick_params(labelsize=7)
                        # ax4.yaxis.set_tick_params(labelsize=7)
                        #
                        #
                        # plots = [plot1, plot2, plot3, plot4]
                        #
                        # #fig.colorbar(plots[1], ax = [ax1, ax2])
                        #
                        # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7)
                        # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7)
                        # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7)
                        #
                        # def animation_sample(t):
                        #     input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
                        #     gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922
                        #     diff_image = input_gen_diff[t,:,:]
                        #     # p = sns.lineplot(x=x,y=data,color="b")
                        #     # p.tick_params(labelsize=17)
                        #     # plt.setp(p.lines, linewidth=6)
                        #     plots[0].set_data(input_image)
                        #     plots[1].set_data(gen_image)
                        #     #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
                        #     #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images))
                        #     plots[2].set_data(diff_image)
                        #
                        #     if t >= future_length:
                        #         #data = gen_mse_avg_[:t + 1]
                        #         # x = list(range(len(gen_mse_avg_)))[:t+1]
                        #         xdata.append(t-future_length)
                        #         print("xdata", xdata)
                        #         ydata.append(gen_mse_avg_[t])
                        #         print("ydata", ydata)
                        #         plots[3].set_data(xdata, ydata)
                        #         fig.suptitle("Predicted Frame " + str(t-future_length))
                        #     else:
                        #         #plots[3].set_data(xdata, ydata)
                        #         fig.suptitle("Context Frame " + str(t))
                        #     return plots
                        #
                        # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000,
                        #                               repeat_delay=2000)
                        # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4"))
    
                    else:
                        pass
    
    
    
            if sample_ind == 0:
                gen_images_all = gen_images_stochastic
            else:
                gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1)
    
            if args.num_stochastic_samples == 1:
                with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files:
                    pickle.dump(list(gen_images_all[0]), gen_files)
            else:
                with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files:
                    pickle.dump(list(gen_images_stochastic), gen_files)
                with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files:
                    pickle.dump(list(gen_images_all), gen_files)
    
    
    
    
            sample_ind += args.batch_size
    
    
        #         # for i, gen_mse_avg_ in enumerate(gen_mse_avg):
        #         #     ims = []
        #         #     fig = plt.figure()
        #         #     plt.xlim(0,len(gen_mse_avg_))
        #         #     plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg))
        #         #     plt.xlabel("Frames")
        #         #     plt.ylabel("MSE_AVG")
        #         #     #X = list(range(len(gen_mse_avg_)))
        #         #     #for t, gen_mse_avg_ in enumerate(gen_mse_avg):
        #         #     def animate_metric(j):
        #         #         data = gen_mse_avg_[:(j+1)]
        #         #         x = list(range(len(gen_mse_avg_)))[:(j+1)]
        #         #         p = sns.lineplot(x=x,y=data,color="b")
        #         #         p.tick_params(labelsize=17)
        #         #         plt.setp(p.lines, linewidth=6)
        #         #     ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000)
        #         #     ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif"))
        #         #
        #         #
        #         # for i, input_images_ in enumerate(input_images):
        #         #     #context_images_ = (input_results['images'][i])
        #         #     #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
        #         #     ims = []
        #         #     fig = plt.figure()
        #         #     for t, input_image in enumerate(input_images_):
        #         #         im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')
        #         #         ttl = plt.text(1.5, 2,"Frame_" + str(t))
        #         #         ims.append([im,ttl])
        #         #     ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000)
        #         #     ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif"))
        #         #     #plt.show()
        #         #
        #         # for i,gen_images_ in enumerate(gen_images):
        #         #     ims = []
        #         #     fig = plt.figure()
        #         #     for t, gen_image in enumerate(gen_images_):
        #         #         im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
        #         #         ttl = plt.text(1.5, 2, "Frame_" + str(t))
        #         #         ims.append([im, ttl])
        #         #     ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
        #         #     ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif"))
        #
        #
        #             # for i, gen_images_ in enumerate(gen_images):
        #             #     #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8)
        #             #     #gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
        #             #     #bing
        #             #     context_images_ = (input_results['images'][i])
        #             #     gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind)
        #             #     context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_)
        #             #     plt.figure(figsize = (10,2))
        #             #     gs = gridspec.GridSpec(2,10)
        #             #     gs.update(wspace=0.,hspace=0.)
        #             #     for t, gen_image in enumerate(gen_images_):
        #             #         gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1)))
        #             #         gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
        #             #         plt.subplot(gs[t])
        #             #         plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none')  # the last index sets the channel. 0 = t2
        #             #         # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
        #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
        #             #                         right = False, labelbottom = False, labelleft = False)
        #             #         if t == 0: plt.ylabel('Actual', fontsize = 10)
        #             #
        #             #         plt.subplot(gs[t + 10])
        #             #         plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none')
        #             #         # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet)
        #             #         plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False,
        #             #                         right = False, labelbottom = False, labelleft = False)
        #             #         if t == 0: plt.ylabel('Predicted', fontsize = 10)
        #             #     plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png')
        #             #     plt.clf()
        #
        #             # if args.gif_length:
        #             #     context_and_gen_images = context_and_gen_images[:args.gif_length]
        #             # save_gif(os.path.join(args.output_gif_dir, gen_images_fname),
        #             #          context_and_gen_images, fps=args.fps)
        #             #
        #             # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
        #             # for t, gen_image in enumerate(gen_images_):
        #             #     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
        #             #     if gen_image.shape[-1] == 1:
        #             #       gen_image = np.tile(gen_image, (1, 1, 3))
        #             #     else:
        #             #       gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
        #             #     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
    
    
    
        with open(os.path.join(args.output_png_dir, "input_images_all"),"rb") as input_files:
            input_images_all = pickle.load(input_files)
    
        with open(os.path.join(args.output_png_dir, "gen_images_all"),"rb") as gen_files:
            gen_images_all = pickle.load(gen_files)
    
        input_images_all = np.array(input_images_all)
        input_images_all = np.array(input_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
        if len(np.array(gen_images_all).shape) == 6:
            for i in range(len(gen_images_all)):
                gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:]
                gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
                mse_all = []
                psnr_all = []
                ssim_all = []
                f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w')
                for i in range(future_length):
                    mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :,
                                                                                0]) ** 2)  # look at all timesteps except the first
                    psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0])
                    ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0],
                                      data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min(
                                          input_images_all[:, i + 10, :, :, 0].flatten()))
                    mse_all.extend([mse_model])
                    psnr_all.extend([psnr_model])
                    ssim_all.extend([ssim_model])
                    results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
                    f.write("##########Predicted Frame {}\n".format(str(i + 1)))
                    f.write("Model MSE: %f\n" % mse_model)
                    # f.write("Previous Frame MSE: %f\n" % mse_prev)
                    f.write("Model PSNR: %f\n" % psnr_model)
                    f.write("Model SSIM: %f\n" % ssim_model)
    
                pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb"))
                # f.write("Previous frame PSNR: %f\n" % psnr_prev)
                f.write("Shape of X_test: " + str(input_images_all.shape))
                f.write("")
                f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape))
    
        else:
            gen_images_all = np.array(gen_images_all) * (321.46630859375 - 235.2141571044922) + 235.2141571044922
            # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2)  # look at all timesteps except the first
            # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2)
            # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 )
            mse_all = []
            psnr_all = []
            ssim_all = []
            f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w')
            for i in range(future_length):
                mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :,
                                                                            0]) ** 2)  # look at all timesteps except the first
                psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0])
                ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0],
                                  data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min(
                                      input_images_all[:, i + 10, :, :, 0].flatten()))
                mse_all.extend([mse_model])
                psnr_all.extend([psnr_model])
                ssim_all.extend([ssim_model])
                results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all}
                f.write("##########Predicted Frame {}\n".format(str(i + 1)))
                f.write("Model MSE: %f\n" % mse_model)
                # f.write("Previous Frame MSE: %f\n" % mse_prev)
                f.write("Model PSNR: %f\n" % psnr_model)
                f.write("Model SSIM: %f\n" % ssim_model)
    
            pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb"))
            # f.write("Previous frame PSNR: %f\n" % psnr_prev)
            f.write("Shape of X_test: " + str(input_images_all.shape))
            f.write("")
            f.write("Shape of X_hat: " + str(gen_images_all.shape))
    
    
    
        psnr_model = psnr(input_images_all[:, :10, :, :, 0],  gen_images_all[:, :10, :, :, 0])
        psnr_model_last = psnr(input_images_all[:, 10, :, :, 0],  gen_images_all[:,10, :, :, 0])
        #psnr_prev = psnr(input_images_all[:, :, :, :, 0],  input_images_all[:, 1:10, :, :, 0])
    
        # ims = []
        # fig = plt.figure()
        # for frame in range(20):
        #     input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel)
        #     #pix_mean = np.mean(input_gen_diff, axis = 0)
        #     #pix_std = np.std(input_gen_diff, axis=0)
        #     im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu')
        #     if frame == 0:
        #         fig.colorbar(im)
        #     ttl = plt.text(1.5, 2, "Frame_" + str(frame +1))
        #     ims.append([im, ttl])
        # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000)
        # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4"))
        # plt.close("all")
    
        # ims = []
        # fig = plt.figure()
        # for frame in range(19):
        #     pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0]  # Get the first prediction frame (batch,height, width, channel)
        #     #pix_mean = np.mean(input_gen_diff, axis = 0)
        #     #pix_std = np.std(input_gen_diff, axis=0)
        #     im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu')
        #     if frame == 0:
        #         fig.colorbar(im)
        #     ttl = plt.text(1.5, 2, "Frame_" + str(frame+1))
        #     ims.append([im, ttl])
        # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000)
        # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4"))
    
        # seed(1)
        # s = random.sample(range(len(gen_images_all)), 100)
        # print("******KDP******")
        # #kernel density plot for checking the model collapse
        # fig = plt.figure()
        # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images")
        # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True")
        # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability')
        # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400)
        # plt.clf()
    
        #line plot for evaluating the prediction and groud-truth
        # for i in [0,3,6,9,12,15,18]:
        #     fig = plt.figure()
        #     plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3)
        #     #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3)
        #     plt.xlabel("Prediction")
        #     plt.ylabel("Real values")
        #     plt.title("Frame_{}".format(i+1))
        #     plt.plot([250,300], [250,300],color="black")
        #     plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i))))
        #     plt.clf()
        #
        # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence)
        # x = [str(i+1) for i in list(range(19))]
        # fig,axis = plt.subplots()
        # mean_f = np.mean(mse_model_by_frames, axis = 0)
        # median = np.median(mse_model_by_frames, axis=0)
        # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0)
        # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0)
        # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0)
        # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0)
        # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range")
        # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range")
        # plt.plot(x, median, color="grey", linewidth=0.6, label="Median")
        # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean")
        # plt.title(f'MSE percentile')
        # plt.xlabel("Frames")
        # plt.legend(loc=2, fontsize=8)
        # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png"))
    
    # Script entry point: delegate all work to main() so the module can be
    # imported (e.g. for its psnr() helper) without side effects.
    if __name__ == '__main__':
        main()