diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a81e9a1499ce2619c6d934d32396c7128bd6b565
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh
@@ -0,0 +1,37 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=DataPreprocess_to_tf-out.%j
+#SBATCH --error=DataPreprocess_to_tf-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --partition=devel
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+
+
+# Name of virtual environment
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory variables which will be modified appropriately during preprocessing (invoked by mpi_split_data_multi_years.py)
+
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+
+# run preprocessing (step 2, where TFRecords are generated)
+srun python ../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir}/tfrecords
diff --git a/video_prediction_savp/HPC_scripts/generate_movingmnist.sh b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..1de81d2543d255a160ff811ff391a963ef712bde
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh
@@ -0,0 +1,44 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=generate_era5-out.%j
+#SBATCH --error=generate_era5-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=develgpus
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=s.stadtler@fz-juelich.de
+##jutil env activate -p cjjsc42
+
+# Name of virtual environment
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+# declare directory variables which will be modified appropriately during preprocessing (invoked by mpi_split_data_multi_years.py)
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+checkpoint_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist
+results_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist
+# name of model
+model=convLSTM
+
+# run postprocessing/generation of model results including evaluation metrics
+srun python -u ../scripts/generate_movingmnist.py \
+--input_dir ${source_dir}/ --dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model} \
+--mode test --model ${model} --results_dir ${results_dir}/${model} --batch_size 2 --dataset era5 > generate_era5-out.out
+
+#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/video_prediction_savp/HPC_scripts/hyperparam_setup.sh b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
index a6c24a062ca30b06879641806d771beacc4b34f8..34894da8c3e345955c05b69994f4f3cf431174ff 100644
--- a/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
+++ b/video_prediction_savp/HPC_scripts/hyperparam_setup.sh
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 
-# for choosing the model
+# for choosing the model: convLSTM, vae, mcnet, savp
 export model=convLSTM
 export model_hparams=../hparams/era5/${model}/model_hparams.json
 
diff --git a/video_prediction_savp/HPC_scripts/train_movingmnist.sh b/video_prediction_savp/HPC_scripts/train_movingmnist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..006ff73c30c4a53c80aef9371bfbe29fac39f973
--- /dev/null
+++ b/video_prediction_savp/HPC_scripts/train_movingmnist.sh
@@ -0,0 +1,47 @@
+#!/bin/bash -x
+#SBATCH --account=deepacf
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+##SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --output=train_moving_mnist-out.%j
+#SBATCH --error=train_moving_mnist-err.%j
+#SBATCH --time=00:20:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=gpus
+#SBATCH --mail-type=ALL
+#SBATCH --mail-user=b.gong@fz-juelich.de
+##jutil env activate -p cjjsc42
+
+
+# Name of virtual environment
+VIRT_ENV_NAME="vp"
+
+# Loading modules
+source ../env_setup/modules_train.sh
+# Activate virtual environment if needed (and possible)
+if [ -z ${VIRTUAL_ENV} ]; then
+   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+      echo "Activating virtual environment..."
+      source ../${VIRT_ENV_NAME}/bin/activate
+   else
+      echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
+      exit 1
+   fi
+fi
+
+
+# declare directory variables which will be modified appropriately during preprocessing (invoked by mpi_split_data_multi_years.py)
+
+source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
+destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist
+
+# for choosing the model: convLSTM, savp, mcnet, vae
+model=convLSTM
+dataset=moving_mnist
+model_hparams=../hparams/${dataset}/${model}/model_hparams.json
+destination_dir=${destination_dir}/${model}/"$(date +"%Y%m%dT%H%M")_"$USER""
+
+# run training
+
+srun python ../scripts/train_dummy.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/
diff --git a/video_prediction_savp/env_setup/requirements.txt b/video_prediction_savp/env_setup/requirements.txt
index 4bf2f0b25d082c4c503bbd56f46d28360a48df43..173b8a10c8dec1d8186adc84c144b79863406d3f 100644
--- a/video_prediction_savp/env_setup/requirements.txt
+++ b/video_prediction_savp/env_setup/requirements.txt
@@ -1,4 +1,4 @@
-opencv-python
+opencv-python==4.2.0.34
 scipy
 scikit-image
 pandas
diff --git a/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
index c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8..fde951edd2e6b41965fbdce6ce831c1e154cbd0e 100644
--- a/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
+++ b/video_prediction_savp/hparams/era5/convLSTM/model_hparams.json
@@ -1,10 +1,11 @@
 {
-    "batch_size": 10,
+    "batch_size": 4,
     "lr": 0.001,
-    "max_epochs":2,
+    "max_epochs":20,
     "context_frames":10,
-    "sequence_length":20
+    "sequence_length":20,
+    "loss_fun":"rmse"
 }
 
 
 
diff --git a/video_prediction_savp/hparams/moving_mnist/convLSTM/model_hparams.json b/video_prediction_savp/hparams/moving_mnist/convLSTM/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..b59f6cb2ee96162b2eb6014d7ca6bd37f54d4218
--- /dev/null
+++ b/video_prediction_savp/hparams/moving_mnist/convLSTM/model_hparams.json
@@ -0,0 +1,12 @@
+
+{
+    "batch_size": 10,
+    "lr": 0.001,
+    "max_epochs":20,
+    "context_frames":10,
+    "sequence_length":20,
+    "loss_fun":"cross_entropy"
+}
+
+
+
diff --git a/video_prediction_savp/scripts/generate_movingmnist.py b/video_prediction_savp/scripts/generate_movingmnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fbf5eb5d8d8f4cad87ae26d15bc2787d9e6c0a
--- /dev/null
+++ b/video_prediction_savp/scripts/generate_movingmnist.py
@@ -0,0 +1,822 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import errno
+import json
+import os
+import math
+import random
+import cv2
+import numpy as np
+import tensorflow as tf
+import pickle
+from random import seed
+import random
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+import matplotlib.animation as animation
+import pandas as pd
+import re
+from video_prediction import datasets, models
+from matplotlib.colors import LinearSegmentedColormap
+#from matplotlib.ticker import MaxNLocator
+#from video_prediction.utils.ffmpeg_gif import save_gif
+from skimage.metrics import structural_similarity as ssim
+import datetime
+# Scarlet 2020/05/28: access to statistical values in json file
+from os import path
+import sys
+sys.path.append(path.abspath('../video_prediction/datasets/')) +from era5_dataset_v2 import Norm_data +from os.path import dirname +from netCDF4 import Dataset,date2num +from metadata import MetaData as MetaData + +def set_seed(seed): + if seed is not None: + tf.set_random_seed(seed) + np.random.seed(seed) + random.seed(seed) + +def get_coordinates(metadata_fname): + """ + Retrieves the latitudes and longitudes read from the metadata json file. + """ + md = MetaData(json_file=metadata_fname) + md.get_metadata_from_file(metadata_fname) + + try: + print("lat:",md.lat) + print("lon:",md.lon) + return md.lat, md.lon + except: + raise ValueError("Error when handling: '"+metadata_fname+"'") + + +def load_checkpoints_and_create_output_dirs(checkpoint,dataset,model): + if checkpoint: + checkpoint_dir = os.path.normpath(checkpoint) + if not os.path.isdir(checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % checkpoint) + options = json.loads(f.read()) + dataset = dataset or options['dataset'] + model = model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + else: + if not dataset: + raise ValueError('dataset is required when checkpoint is not specified') + if not model: + raise ValueError('model is required when checkpoint is not specified') + + return options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict + + + +def setup_dataset(dataset,input_dir,mode,seed,num_epochs,dataset_hparams,dataset_hparams_dict): + VideoDataset = datasets.get_dataset_class(dataset) + dataset = VideoDataset( + input_dir, + mode = mode, + num_epochs = num_epochs, + seed = seed, + hparams_dict = dataset_hparams_dict, + hparams = dataset_hparams) + return dataset + + +def setup_dirs(input_dir,results_png_dir): + input_dir = args.input_dir + temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/" + print ("temporal_dir:",temporal_dir) + + +def update_hparams_dict(model_hparams_dict,dataset): + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + return hparams_dict + + +def psnr(img1, img2): + mse = np.mean((img1 - img2) ** 2) + if mse == 0: return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + + +def setup_num_samples_per_epoch(num_samples, dataset): + if num_samples: + if num_samples > dataset.num_examples_per_epoch(): + raise ValueError('num_samples cannot be larger than the dataset') + num_examples_per_epoch = num_samples + else: + num_examples_per_epoch = dataset.num_examples_per_epoch() + #if num_examples_per_epoch % args.batch_size != 0: + # raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + return num_examples_per_epoch + + +def initia_save_data(): + sample_ind = 0 + gen_images_all = [] + 
#Bing:20200410
+    persistent_images_all = []
+    input_images_all = []
+    return sample_ind, gen_images_all,persistent_images_all, input_images_all
+
+
+def write_params_to_results_dir(args,output_dir,dataset,model):
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(args), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
+        f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4))
+    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
+        f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
+    return None
+
+def get_one_seq_and_time(input_images,i):
+    assert (len(np.array(input_images).shape)==5)
+    input_images_ = input_images[i,:,:,:,:]
+    return input_images_
+
+
+def denorm_images_all_channels(input_images_):
+    input_images_all_channels_denorm = []
+    input_images_ = np.array(input_images_)
+    input_images_denorm = input_images_ * 255.0
+    #print("input_images_denorm shape",input_images_denorm.shape)
+    return input_images_denorm
+
+def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"):
+    """
+    Plot the sequence of images
+    """
+
+    if len(np.array(imgs).shape)!=3: raise ValueError("img dims should be three: (seq_len,lat,lon)")
+    img_len = imgs.shape[0]
+    fig = plt.figure(figsize=(18,6))
+    gs = gridspec.GridSpec(1, 10)
+    gs.update(wspace = 0., hspace = 0.)
+    for i in range(img_len):
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i] ,cmap = 'jet')
+        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
+    plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg"))
+    print("images_saved")
+    plt.clf()
+
+
+
+def get_persistence(ts):
+    pass
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_dir", type = str, required = True,
+                        help = "either a directory containing subdirectories "
+                               "train, val, test, etc, or a directory containing "
+                               "the tfrecords")
+    parser.add_argument("--results_dir", type = str, default = 'results',
+                        help = "ignored if output_gif_dir is specified")
+    parser.add_argument("--checkpoint",
+                        help = "directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val', + help = 'mode for dataset, val or test.') + parser.add_argument("--dataset", type = str, help = "dataset class name") + parser.add_argument("--dataset_hparams", type = str, + help = "a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type = str, help = "model class name") + parser.add_argument("--model_hparams", type = str, + help = "a string of comma separated list of model hyperparameters") + parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch") + parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type = int, default = 1) + parser.add_argument("--num_stochastic_samples", type = int, default = 1) + parser.add_argument("--gif_length", type = int, help = "default is sequence_length") + parser.add_argument("--fps", type = int, default = 4) + parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use") + parser.add_argument("--seed", type = int, default = 7) + args = parser.parse_args() + set_seed(args.seed) + + dataset_hparams_dict = {} + model_hparams_dict = {} + + options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict = load_checkpoints_and_create_output_dirs(args.checkpoint,args.dataset,args.model) + print("Step 1 finished") + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + #setup dataset and model object + input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored + dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict) + + print("Step 2 finished") + VideoPredictionModel = models.get_model_class(model) + + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + + model = VideoPredictionModel( + mode = args.mode, + hparams_dict = hparams_dict, + hparams = args.model_hparams) + + sequence_length = model.hparams.sequence_length + context_frames = model.hparams.context_frames + future_length = sequence_length - context_frames #context_Frames is the number of input frames + + num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset) + + inputs = dataset.make_batch(args.batch_size) + print("inputs",inputs) + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + print("input_phs",input_phs) + + + # Build graph + with tf.variable_scope(''): + model.build_graph(input_phs) + + #Write the update hparameters into results_dir + write_params_to_results_dir(args=args,output_dir=args.results_dir,dataset=dataset,model=model) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True) + sess = tf.Session(config = config) + sess.graph.as_default() + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + model.restore(sess, args.checkpoint) + + #model.restore(sess, 
args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistent data
+    sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data()
+
+    is_first=True
+    #loop over samples
+    while sample_ind < 5:
+        gen_images_stochastic = []
+        if args.num_samples and sample_ind >= args.num_samples:
+            break
+        try:
+            input_results = sess.run(inputs)
+            input_images = input_results["images"]
+            #get the initial times
+            t_starts = input_results["T_start"]
+        except tf.errors.OutOfRangeError:
+            break
+
+        #Get prediction values
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel]
+        print("gen_images 20200822:",np.array(gen_images).shape)
+        #Loop over the batch
+        for i in range(args.batch_size):
+
+            #get one seq and the corresponding start time point
+            input_images_ = get_one_seq_and_time(input_images,i)
+
+            #Renormalized data for inputs
+            input_images_denorm = denorm_images_all_channels(input_images_)
+            print("input_images_denorm",input_images_denorm[0][0])
+
+            #Renormalized data for the generated forecasts
+            gen_images_ = gen_images[i]
+            gen_images_denorm = denorm_images_all_channels(gen_images_)
+            print("gen_images_denorm:",gen_images_denorm[0][0])
+
+            #Generate images for the inputs
+            plot_seq_imgs(imgs=input_images_denorm[context_frames+1:,:,:,0],idx = sample_ind + i, label="Ground Truth",output_png_dir=args.results_dir)
+
+            #Generate forecast images
+            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],idx = sample_ind + i,label="Forecast by Model " + args.model,output_png_dir=args.results_dir)
+
+            #TODO: Scarlet, plot persistence image
+            #implement the get_persistence() function
+
+        #instead of generating images for all the inputs, we just generate the first 5 sample_ind examples for visualization
+
+        sample_ind += args.batch_size
+
+
+        #for input_image in input_images_:
+
+#        for stochastic_sample_ind in range(args.num_stochastic_samples):
+#            input_images_all.extend(input_images)
+#            with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files:
+#                pickle.dump(list(input_images_all), input_files)
+
+
+#            gen_images_stochastic.append(gen_images)
+#            #print("Stochastic_sample,", stochastic_sample_ind)
+#            for i in range(args.batch_size):
+#                #bing:20200417
+#                t_stampe = test_temporal_pkl[sample_ind+i]
+#                print("timestamp:",type(t_stampe))
+#                persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1)
+#                print ("persistent ts",persistent_ts)
+#                persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts))
+#                persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length]
+#                print("persistent index in test set:", persistent_idx)
+#                print("persistent_X.shape",persistent_X.shape)
+#                persistent_images_all.append(persistent_X)
+
+#                cmap_name = 'my_list'
+#                if sample_ind < 100:
+#                    #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str(
+#                    #    sample_ind) + " + Sample_" + str(i)
+#                    name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S")
+#                    print ("name",name)
+#                    gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :]))
+#                    #gen_images_ = gen_images[i, :]
+#                    input_images_ = input_images[i, :]
+#                    #Bing:20200417
+#                    #persistent_images = ?
+# #+++Scarlet:20200528 +# #print('Scarlet1') +# input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm) +# persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm) +# #---Scarlet:20200528 +# gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in +# range(sequence_length)] # return the list with 10 (sequence) mse +# persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in +# range(sequence_length)] # return the list with 10 (sequence) mse + +# fig = plt.figure(figsize=(18,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) +# ts = list(range(10,20)) #[10,11,12,..] +# xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] +# ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + +# for t in ts: + +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #+++Scarlet:20200528 +# #print('Scarlet2') +# input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm) +# #---Scarlet:20200528 +# plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) +# if t == 0: +# plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels) +# plt.ylabel("Ground Truth", fontsize=10) +# plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg")) +# plt.clf() + +# fig = plt.figure(figsize=(12,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) + +# for t in ts: +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #+++Scarlet:20200528 +# #print('Scarlet3') +# gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm) +# #---Scarlet:20200528 +# plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + +# plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg")) +# plt.clf() + + +# fig = plt.figure(figsize=(12,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) 
+# for t in ts: +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 +# plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + +# plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg")) +# plt.clf() + + +# with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: +# pickle.dump(list(persistent_images_all), input_files) +# print ("Save persistent all") +# if is_first: +# gen_images_all = gen_images_stochastic +# is_first = False +# else: +# gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + +# if args.num_stochastic_samples == 1: +# with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: +# pickle.dump(list(gen_images_all[0]), gen_files) +# print ("Save generate all") +# else: +# with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: +# pickle.dump(list(gen_images_stochastic), gen_files) +# with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: +# pickle.dump(list(gen_images_all), gen_files) + +# sample_ind += args.batch_size + + +# with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files: +# input_images_all = pickle.load(input_files) + +# with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files: +# gen_images_all = pickle.load(gen_files) + +# with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files: +# persistent_images_all = pickle.load(gen_files) + +# #+++Scarlet:20200528 +# #print('Scarlet4') +# input_images_all = np.array(input_images_all) +# input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm) +# #---Scarlet:20200528 +# persistent_images_all = np.array(persistent_images_all) +# if len(np.array(gen_images_all).shape) == 6: +# for i in range(len(gen_images_all)): +# #+++Scarlet:20200528 +# #print('Scarlet5') +# gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:] +# gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm) +# #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 +# #---Scarlet:20200528 +# mse_all = [] +# psnr_all = [] +# ssim_all = [] +# f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w') +# for i in range(future_length): +# mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first +# psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0]) +# ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0], +# data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min( +# input_images_all[:, i + 10, :, :, 0].flatten())) +# mse_all.extend([mse_model]) +# psnr_all.extend([psnr_model]) +# ssim_all.extend([ssim_model]) +# results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} +# f.write("##########Predicted Frame {}\n".format(str(i + 1))) +# f.write("Model MSE: %f\n" % mse_model) +# 
# f.write("Previous Frame MSE: %f\n" % mse_prev) +# f.write("Model PSNR: %f\n" % psnr_model) +# f.write("Model SSIM: %f\n" % ssim_model) + + +# pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb")) +# # f.write("Previous frame PSNR: %f\n" % psnr_prev) +# f.write("Shape of X_test: " + str(input_images_all.shape)) +# f.write("") +# f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape)) + +# else: +# #+++Scarlet:20200528 +# #print('Scarlet6') +# gen_images_all = np.array(gen_images_all) +# gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm) +# #---Scarlet:20200528 + +# # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first +# # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) +# # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 ) +# mse_all = [] +# psnr_all = [] +# ssim_all = [] +# persistent_mse_all = [] +# persistent_psnr_all = [] +# persistent_ssim_all = [] +# f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w') +# for i in range(future_length): +# mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first +# persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first + +# psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0]) +# ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0], +# data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min( +# input_images_all[:, i + 10, :, :, 0].flatten())) +# persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0]) +# persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0], +# data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten())) +# mse_all.extend([mse_model]) +# psnr_all.extend([psnr_model]) +# ssim_all.extend([ssim_model]) +# persistent_mse_all.extend([persistent_mse_model]) +# persistent_psnr_all.extend([persistent_psnr_model]) +# persistent_ssim_all.extend([persistent_ssim_model]) +# results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} + +# persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all} +# f.write("##########Predicted Frame {}\n".format(str(i + 1))) +# f.write("Model MSE: %f\n" % mse_model) +# # f.write("Previous Frame MSE: %f\n" % mse_prev) +# f.write("Model PSNR: %f\n" % psnr_model) +# f.write("Model SSIM: %f\n" % ssim_model) + +# pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb")) +# pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb")) +# # f.write("Previous frame PSNR: %f\n" % psnr_prev) +# f.write("Shape of X_test: " + str(input_images_all.shape)) +# f.write("") +# f.write("Shape of X_hat: " + str(gen_images_all.shape) + +if __name__ == '__main__': + main() + + #psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0]) + #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0]) + 
#psnr_prev = psnr(input_images_all[:, :, :, :, 0], input_images_all[:, 1:10, :, :, 0]) + + # ims = [] + # fig = plt.figure() + # for frame in range(20): + # input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel) + # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame +1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000) + # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4")) + # plt.close("all") + + # ims = [] + # fig = plt.figure() + # for frame in range(19): + # pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) + # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame+1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4")) + + # seed(1) + # s = random.sample(range(len(gen_images_all)), 100) + # print("******KDP******") + # #kernel density plot for checking the model collapse + # fig = plt.figure() + # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images") + # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True") + # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability') + # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) + # plt.clf() + + #line plot for evaluating the prediction and groud-truth + # for i in [0,3,6,9,12,15,18]: + # fig = plt.figure() + # plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) + # #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3) + # plt.xlabel("Prediction") + # plt.ylabel("Real values") + # plt.title("Frame_{}".format(i+1)) + # plt.plot([250,300], [250,300],color="black") + # plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i)))) + # plt.clf() + # + # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence) + # x = [str(i+1) for i in list(range(19))] + # fig,axis = plt.subplots() + # mean_f = np.mean(mse_model_by_frames, axis = 0) + # median = np.median(mse_model_by_frames, axis=0) + # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0) + # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0) + # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0) + # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0) + # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range") + # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range") + # plt.plot(x, median, color="grey", linewidth=0.6, label="Median") + # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean") + # plt.title(f'MSE percentile') + # plt.xlabel("Frames") + # plt.legend(loc=2, fontsize=8) 
+ # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png")) + + +## +## +## # fig = plt.figure() +## # gs = gridspec.GridSpec(4,6) +## # gs.update(wspace = 0.7,hspace=0.8) +## # ax1 = plt.subplot(gs[0:2,0:3]) +## # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1) +## # ax3 = plt.subplot(gs[2:4,0:3]) +## # ax4 = plt.subplot(gs[2:4,3:]) +## # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] +## # ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] +## # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels) +## # ax1.title.set_text("(a) Ground Truth") +## # ax2.title.set_text("(b) SAVP") +## # ax3.title.set_text("(c) Diff.") +## # ax4.title.set_text("(d) MSE") +## # +## # ax1.xaxis.set_tick_params(labelsize=7) +## # ax1.yaxis.set_tick_params(labelsize = 7) +## # ax2.xaxis.set_tick_params(labelsize=7) +## # ax2.yaxis.set_tick_params(labelsize = 7) +## # ax3.xaxis.set_tick_params(labelsize=7) +## # ax3.yaxis.set_tick_params(labelsize = 7) +## # +## # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2])) +## # print("inti images shape", init_images.shape) +## # xdata, ydata = [], [] +## # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1) +## # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1) +## # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) +## # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) +## # #x = np.linspace(0, 64, 64) +## # #y = np.linspace(0, 64, 64) +## # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) +## # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) +## # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7) +## # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7) +## # +## # cm = LinearSegmentedColormap.from_list( +## # cmap_name, "bwr", N = 5) +## # +## # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r', +## # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm) # cmap = 'PuBu_r', +## # plot4, = ax4.plot([], [], color = "r") +## # ax4.set_xlim(0, future_length-1) +## # ax4.set_ylim(0, 20) +## # #ax4.set_ylim(0, 0.5) +## # ax4.set_xlabel("Frames", fontsize=10) +## # #ax4.set_ylabel("MSE", fontsize=10) +## # ax4.xaxis.set_tick_params(labelsize=7) +## # ax4.yaxis.set_tick_params(labelsize=7) +## # +## # +## # plots = [plot1, plot2, plot3, plot4] +## # +## # #fig.colorbar(plots[1], ax = [ax1, ax2]) +## # +## # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7) +## # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7) +## # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7) +## # +## # def animation_sample(t): +## # input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 +## # gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 +## # diff_image = input_gen_diff[t,:,:] +## # # p = sns.lineplot(x=x,y=data,color="b") +## # # p.tick_params(labelsize=17) +## # # plt.setp(p.lines, linewidth=6) +## # plots[0].set_data(input_image) +## # plots[1].set_data(gen_image) +## # #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) +## # #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = 
np.max(input_images)) +## # plots[2].set_data(diff_image) +## # +## # if t >= future_length: +## # #data = gen_mse_avg_[:t + 1] +## # # x = list(range(len(gen_mse_avg_)))[:t+1] +## # xdata.append(t-future_length) +## # print("xdata", xdata) +## # ydata.append(gen_mse_avg_[t]) +## # print("ydata", ydata) +## # plots[3].set_data(xdata, ydata) +## # fig.suptitle("Predicted Frame " + str(t-future_length)) +## # else: +## # #plots[3].set_data(xdata, ydata) +## # fig.suptitle("Context Frame " + str(t)) +## # return plots +## # +## # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000, +## # repeat_delay=2000) +## # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) +## +#### else: +#### pass +## + + + + + + # # for i, gen_mse_avg_ in enumerate(gen_mse_avg): + # # ims = [] + # # fig = plt.figure() + # # plt.xlim(0,len(gen_mse_avg_)) + # # plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg)) + # # plt.xlabel("Frames") + # # plt.ylabel("MSE_AVG") + # # #X = list(range(len(gen_mse_avg_))) + # # #for t, gen_mse_avg_ in enumerate(gen_mse_avg): + # # def animate_metric(j): + # # data = gen_mse_avg_[:(j+1)] + # # x = list(range(len(gen_mse_avg_)))[:(j+1)] + # # p = sns.lineplot(x=x,y=data,color="b") + # # p.tick_params(labelsize=17) + # # plt.setp(p.lines, linewidth=6) + # # ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif")) + # # + # # + # # for i, input_images_ in enumerate(input_images): + # # #context_images_ = (input_results['images'][i]) + # # #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # ims = [] + # # fig = plt.figure() + # # for t, input_image in enumerate(input_images_): + # # im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2,"Frame_" + str(t)) + # # ims.append([im,ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif")) + # # #plt.show() + # # + # # for i,gen_images_ in enumerate(gen_images): + # # ims = [] + # # fig = plt.figure() + # # for t, gen_image in enumerate(gen_images_): + # # im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2, "Frame_" + str(t)) + # # ims.append([im, ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # # ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif")) + # + # + # # for i, gen_images_ in enumerate(gen_images): + # # #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) + # # #gen_images_ = (gen_images_ * 255.0).astype(np.uint8) + # # #bing + # # context_images_ = (input_results['images'][i]) + # # gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + # # plt.figure(figsize = (10,2)) + # # gs = gridspec.GridSpec(2,10) + # # gs.update(wspace=0.,hspace=0.) 
+ # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1))) + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # plt.subplot(gs[t]) + # # plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') # the last index sets the channel. 0 = t2 + # # # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Actual', fontsize = 10) + # # + # # plt.subplot(gs[t + 10]) + # # plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Predicted', fontsize = 10) + # # plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png') + # # plt.clf() + # + # # if args.gif_length: + # # context_and_gen_images = context_and_gen_images[:args.gif_length] + # # save_gif(os.path.join(args.output_gif_dir, gen_images_fname), + # # context_and_gen_images, fps=args.fps) + # # + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # if gen_image.shape[-1] == 1: + # # gen_image = np.tile(gen_image, (1, 1, 3)) + # # else: + # # gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) + # # cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index 0e250b47df28d115c8cdfc77fc708eab5e094ce6..1fcbd1cf97442f1ea440039a0bb6769473b957f3 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -202,11 +202,6 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t ts_len = len(ts) ts_input = ts[:context_frames] ts_forecast = ts[context_frames:] - #print("context_frame:",context_frames) - #print("future_frame",future_length) - #print("length of ts input:",len(ts_input)) - - print("input_images_ shape in netcdf,",input_images_.shape) gen_images_ = np.array(gen_images_) output_file = os.path.join(output_dir,fl_name) @@ -289,18 +284,19 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t #Temperature: t2 = nc_file.createVariable("/forecast/{}/T2".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) t2.units = 'K' - t2[:,:,:] = gen_images_[context_frames:,:,:,0] + print ("gen_images_ 20200822:",np.array(gen_images_).shape) + t2[:,:,:] = gen_images_[context_frames-1:,:,:,0] print("NetCDF created") #mean sea level pressure msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) msl.units = 'Pa' - msl[:,:,:] = gen_images_[context_frames:,:,:,1] + msl[:,:,:] = gen_images_[context_frames-1:,:,:,1] #Geopotential at 500 gph500 = 
nc_file.createVariable("/forecast/{}/GPH500".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) gph500.units = 'm' - gph500[:,:,:] = gen_images_[context_frames:,:,:,2] + gph500[:,:,:] = gen_images_[context_frames-1:,:,:,2] print("{} created".format(output_file)) @@ -451,7 +447,8 @@ def main(): #Get prediction values feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel] - + assert gen_images.shape[1] == sequence_length-1 #The generate images seq_len should be sequence_len -1, since the last one is not used for comparing with groud truth + print("gen_images 20200822:",np.array(gen_images).shape) #Loop in batch size for i in range(args.batch_size): @@ -459,26 +456,26 @@ def main(): input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i) #generate time stamps for sequences ts = generate_seq_timestamps(t_start,len_seq=sequence_length) - + #Renormalized data for inputs stat_fl = os.path.join(args.input_dir,"pickle/statistics.json") input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"]) - print("input_images_denorm",input_images_denorm[0][0]) + print("input_images_denorm shape",np.array(input_images_denorm).shape) #Renormalized data for inputs gen_images_ = gen_images[i] gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"]) - print("gene_images_denorm:",gen_images_denorm[0][0]) + print("gene_images_denorm shape",np.array(gen_images_denorm).shape) #Save input to netCDF file init_date_str = ts[0].strftime("%Y%m%d%H") save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str)) #Generate images inputs - plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir) + plot_seq_imgs(imgs=input_images_denorm[context_frames+1:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Ground Truth",output_png_dir=args.results_dir) #Generate forecast images - plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) + plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) #TODO: Scaret plot persistence image #implment get_persistence() function diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py index d4db02c3804c445ee434d2fbd4de01ae4ae2dd29..0417a36514fd6136fb9fbe934bfb396633fa6093 100644 --- a/video_prediction_savp/scripts/train_dummy.py +++ b/video_prediction_savp/scripts/train_dummy.py @@ -16,13 +16,6 @@ from json import JSONEncoder import pickle as pkl -class NumpyArrayEncoder(JSONEncoder): - def default(self, obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - return JSONEncoder.default(self, obj) - - def add_tag_suffix(summary, tag_suffix): summary_proto = tf.Summary() summary_proto.ParseFromString(summary) @@ -80,7 +73,6 @@ def set_seed(seed): random.seed(seed) def load_params_from_checkpoints_dir(model_hparams_dict,checkpoint,dataset,model): - model_hparams_dict_load = {} if model_hparams_dict: with open(model_hparams_dict) as f: @@ 
-159,8 +151,19 @@
     return inputs,train_handle, val_handle
 
 
-def plot_train(train_losses,val_losses,output_dir):
-    iterations = list(range(len(train_losses)))
+def plot_train(train_losses,val_losses,step,output_dir):
+    """
+    Function to plot training losses for train and val datasets against steps
+    params:
+        train_losses/val_losses (list): training/validation losses, whose length should be equal to the number of training steps
+        step (int): current training step
+        output_dir (str): the path to save the plot
+
+    return: None
+    """
+
+    iterations = list(range(len(train_losses)))
+    if len(train_losses) != len(val_losses) or len(train_losses) != step + 1: raise ValueError("The length of training losses must be equal to the length of val losses and step + 1!")
     plt.plot(iterations, train_losses, 'g', label='Training loss')
     plt.plot(iterations, val_losses, 'b', label='validation loss')
     plt.title('Training and Validation loss')
@@ -168,6 +171,8 @@ def plot_train(train_losses,val_losses,output_dir):
     plt.ylabel('Loss')
     plt.legend()
     plt.savefig(os.path.join(output_dir,'plot_train.png'))
+    plt.close()
+    return None
 
 def save_results_to_dict(results_dict,output_dir):
     with open(os.path.join(output_dir,"results.json"),"w") as fp:
@@ -232,7 +237,9 @@ def main():
     inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
 
     #build model graph
-    del inputs["T_start"]
+    #since the era5 tfrecords include T_start, it must be removed from the inputs when training the model, otherwise the model will raise an error
+    if args.dataset == "era5":
+        del inputs["T_start"]
     model.build_graph(inputs)
 
     #save all the model, data params to output dirctory
@@ -255,6 +262,7 @@ def main():
     num_examples_per_epoch = train_dataset.num_examples_per_epoch()
     print ("number of exmaples per epoch:",num_examples_per_epoch)
     steps_per_epoch = int(num_examples_per_epoch/batch_size)
+    #the total number of steps equals the number of steps per epoch multiplied by the number of epochs
     total_steps = steps_per_epoch * max_epochs
     global_step = tf.train.get_or_create_global_step()
     #mock total_steps only for fast debugging
@@ -276,17 +284,18 @@ def main():
         val_losses=[]
         run_start_time = time.time()
 
         for step in range(start_step,total_steps):
-            #global_step = sess.run(global_step):q
-
+            #global_step = sess.run(global_step)
+            # +++ Scarlet 20200813
+            timeit_start = time.time()
+            # --- Scarlet 20200813
             print ("step:", step)
             val_handle_eval = sess.run(val_handle)
-            #Fetch variables in the graph
-
             fetches = {"train_op": model.train_op}
             #fetches["latent_loss"] = model.latent_loss
             fetches["summary"] = model.summary_op
-
+            fetches["global_step"] = model.global_step
+
             if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
                 fetches["global_step"] = model.global_step
                 fetches["total_loss"] = model.total_loss
@@ -322,8 +331,8 @@ def main():
 
             val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval})
             val_losses.append(val_results["total_loss"])
-            summary_writer.add_summary(results["summary"])
-            summary_writer.add_summary(val_results["summary"])
+            summary_writer.add_summary(results["summary"],results["global_step"])
+            summary_writer.add_summary(val_results["summary"],results["global_step"])
             summary_writer.flush()
 
             # global_step will have the correct step count if we resume from a checkpoint
@@ -342,16 +351,30 @@ def main():
                 print
("The model name does not exist") #print("saving model to", args.output_dir) - saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)# + + saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step) + # +++ Scarlet 20200813 + timeit_end = time.time() + # --- Scarlet 20200813 + print("time needed for this step", timeit_end - timeit_start, ' s') + if step % 20 == 0: + # I save the pickle file and plot here inside the loop in case the training process cannot finished after job is done. + save_results_to_pkl(train_losses,val_losses,args.output_dir) + plot_train(train_losses,val_losses,step,args.output_dir) + + train_time = time.time() - run_start_time results_dict = {"train_time":train_time, "total_steps":total_steps} save_results_to_dict(results_dict,args.output_dir) - save_results_to_pkl(train_losses, val_losses, args.output_dir) + #save_results_to_pkl(train_losses, val_losses, args.output_dir) print("train_losses:",train_losses) print("val_losses:",val_losses) - plot_train(train_losses,val_losses,args.output_dir) + #plot_train(train_losses,val_losses,args.output_dir) print("Done") + # +++ Scarlet 20200814 + print("Total training time:", train_time/60., "min") + # +++ Scarlet 20200814 if __name__ == '__main__': main() diff --git a/video_prediction_savp/video_prediction/datasets/__init__.py b/video_prediction_savp/video_prediction/datasets/__init__.py index e449a65bd48b14ef5a11e6846ce4d8f39f7ed193..f58607c2f4c14047aefb36956e98bd228a30aeb1 100644 --- a/video_prediction_savp/video_prediction/datasets/__init__.py +++ b/video_prediction_savp/video_prediction/datasets/__init__.py @@ -7,6 +7,7 @@ from .kth_dataset import KTHVideoDataset from .ucf101_dataset import UCF101VideoDataset from .cartgripper_dataset import CartgripperVideoDataset from .era5_dataset_v2 import ERA5Dataset_v2 +from .moving_mnist import MovingMnist #from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly def get_dataset_class(dataset): @@ -19,6 +20,7 @@ def get_dataset_class(dataset): 'ucf101': 'UCF101VideoDataset', 'cartgripper': 'CartgripperVideoDataset', "era5":"ERA5Dataset_v2", + "moving_mnist":"MovingMnist" # "era5_anomaly":"ERA5Dataset_v2_anomaly", } dataset_class = dataset_mappings.get(dataset, dataset) diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py index a3c9fc3666eb21ef02f9e5f64c8c95a29034d619..2fb693e4066570a93889850c571654a46fb81e4e 100644 --- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py +++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py @@ -283,46 +283,56 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file, temp_input_file #sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') #Bing 2020/07/16 #print ("open intput dir,",input_file) - with open(input_file, "rb") as data_file: - X_train = pickle.load(data_file) - with open(temp_input_file,"rb") as temp_file: - T_train = pickle.load(temp_file) - #print("T_train:",T_train) - #check to make sure the X_train and T_train has the same length - assert (len(X_train) == len(T_train)) + try: + with open(input_file, "rb") as data_file: + X_train = pickle.load(data_file) + with open(temp_input_file,"rb") as temp_file: + T_train = pickle.load(temp_file) + + #print("T_train:",T_train) + #check to make sure the X_train and T_train has the same length + assert (len(X_train) == len(T_train)) + + X_possible_starts = [i for i in range(len(X_train) - 
seq_length)] + for X_start in X_possible_starts: + X_end = X_start + seq_length + #seq = X_train[X_start:X_end, :, :,:] + seq = X_train[X_start:X_end,:,:,:] + #Recored the start point of the timestamps + T_start = T_train[X_start] + #print("T_start:",T_start) + seq = list(np.array(seq).reshape((seq_length, height, width, nvars))) + if not sequences: + last_start_sequence_iter = sequence_iter + + + sequences.append(seq) + T_start_points.append(T_start) + sequence_iter += 1 + + if len(sequences) == sequences_per_file: + ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables + sequences = np.array(sequences) + ### normalization + for i in range(nvars): + sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm) + + output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year,month,last_start_sequence_iter,sequence_iter - 1) + output_fname = os.path.join(output_dir, output_fname) + print("T_start_points:",T_start_points) + if os.path.isfile(output_fname): + print(output_fname, ' already exists, skip it') + else: + save_tf_record(output_fname, list(sequences), T_start_points) + T_start_points = [] + sequences = [] + print("Finished for input file",input_file) + #sequence_lengths_file.close() - X_possible_starts = [i for i in range(len(X_train) - seq_length)] - for X_start in X_possible_starts: - X_end = X_start + seq_length - #seq = X_train[X_start:X_end, :, :,:] - seq = X_train[X_start:X_end,:,:,:] - #Recored the start point of the timestamps - T_start = T_train[X_start] - #print("T_start:",T_start) - seq = list(np.array(seq).reshape((seq_length, height, width, nvars))) - if not sequences: - last_start_sequence_iter = sequence_iter - - - sequences.append(seq) - T_start_points.append(T_start) - sequence_iter += 1 - - if len(sequences) == sequences_per_file: - ###Normalization should adpot the selected variables, here we used duplicated channel temperature variables - sequences = np.array(sequences) - ### normalization - for i in range(nvars): - sequences[:,:,:,:,i] = norm_cls.norm_var(sequences[:,:,:,:,i],vars_in[i],norm) - - output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year,month,last_start_sequence_iter,sequence_iter - 1) - output_fname = os.path.join(output_dir, output_fname) - print("T_start_points:",T_start_points) - save_tf_record(output_fname, list(sequences), T_start_points) - T_start_points = [] - sequences = [] - print("Finished for input file",input_file) - #sequence_lengths_file.close() + except FileNotFoundError as fnf_error: + print(fnf_error) + pass + return def write_sequence_file(output_dir,seq_length,sequences_per_file): @@ -361,18 +371,18 @@ def main(): ############################################################ partition = { "train":{ - # "2222":[1,2,3,5,6,7,8,9,10,11,12], - # "2010_1":[1,2,3,4,5,6,7,8,9,10,11,12], - # "2012":[1,2,3,4,5,6,7,8,9,10,11,12], - # "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12], - # "2015":[1,2,3,4,5,6,7,8,9,10,11,12], - "2017_test":[1,2,3,4,5,6,7,8,9,10] + # "2222":[1,2,3,5,6,7,8,9,10,11,12], # Issue due to month 04, it is missing + "2010":[1,2,3,4,5,6,7,8,9,10,11,12], + # "2012":[1,2,3,4,5,6,7,8,9,10,11,12], + "2013":[1,2,3,4,5,6,7,8,9,10,11,12], + "2015":[1,2,3,4,5,6,7,8,9,10,11,12], + "2019":[1,2,3,4,5,6,7,8,9,10,11,12] }, "val": - {"2017_test":[11] + {"2017":[1,2,3,4,5,6,7,8,9,10,11,12] }, "test": - {"2017_test":[12] + {"2016":[1,2,3,4,5,6,7,8,9,10,11,12] } } diff --git a/video_prediction_savp/video_prediction/datasets/kth_dataset.py 
b/video_prediction_savp/video_prediction/datasets/kth_dataset.py index 40fb6bf57b8219fc7e75c3759df9e6b38fffeb30..e1e11d51968e706868fd89f26faa25d1999d3a9b 100644 --- a/video_prediction_savp/video_prediction/datasets/kth_dataset.py +++ b/video_prediction_savp/video_prediction/datasets/kth_dataset.py @@ -136,11 +136,11 @@ def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequence def main(): parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, help="directory containing the processed directories " + parser.add_argument("input_dir", type=str, help="directory containing the processed directories " "boxing, handclapping, handwaving, " "jogging, running, walking") - parser.add_argument("--output_dir", type=str) - parser.add_argument("--image_size", type=int) + parser.add_argument("output_dir", type=str) + parser.add_argument("image_size", type=int) args = parser.parse_args() partition_names = ['train', 'val', 'test'] diff --git a/video_prediction_savp/video_prediction/datasets/moving_mnist.py b/video_prediction_savp/video_prediction/datasets/moving_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..12a517ac9540dd443866d47ea689803abe9e9a4d --- /dev/null +++ b/video_prediction_savp/video_prediction/datasets/moving_mnist.py @@ -0,0 +1,241 @@ +import argparse +import glob +import itertools +import os +import pickle +import random +import re +import numpy as np +import json +import tensorflow as tf +from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +# ML 2020/04/14: hack for getting functions of process_netCDF_v2: +from os import path +import sys +sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/')) +import DataPreprocess.process_netCDF_v2 +from DataPreprocess.process_netCDF_v2 import get_unique_vars +from DataPreprocess.process_netCDF_v2 import Calc_data_stat +from metadata import MetaData +#from base_dataset import VarLenFeatureVideoDataset +from collections import OrderedDict +from tensorflow.contrib.training import HParams +from mpi4py import MPI +import glob +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + +class MovingMnist(VarLenFeatureVideoDataset): + def __init__(self, *args, **kwargs): + super(MovingMnist, self).__init__(*args, **kwargs) + from google.protobuf.json_format import MessageToDict + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + print("features in dataset:",feature.keys()) + self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) + self.image_shape = self.video_shape[1:] + self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape + + def get_default_hparams_dict(self): + default_hparams = super(MovingMnist, self).get_default_hparams_dict() + hparams = dict( + context_frames=10,#Bing: Todo oriignal is 10 + sequence_length=20,#bing: TODO original is 20, + shuffle_on_val=True, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + + @property + def jpeg_encoding(self): + return False + + + + def num_examples_per_epoch(self): + with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: + sequence_lengths = sequence_lengths_file.readlines() + sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] + return 
np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) + + def filter(self, serialized_example): + return tf.convert_to_tensor(True) + + + def make_dataset_v2(self, batch_size): + def parser(serialized_example): + seqs = OrderedDict() + keys_to_features = { + 'width': tf.FixedLenFeature([], tf.int64), + 'height': tf.FixedLenFeature([], tf.int64), + 'sequence_length': tf.FixedLenFeature([], tf.int64), + 'channels': tf.FixedLenFeature([], tf.int64), + 'images/encoded': tf.VarLenFeature(tf.float32) + } + + # for i in range(20): + # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) + parsed_features = tf.parse_single_example(serialized_example, keys_to_features) + print ("Parse features", parsed_features) + seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) + #width = tf.sparse_tensor_to_dense(parsed_features["width"]) + # height = tf.sparse_tensor_to_dense(parsed_features["height"]) + # channels = tf.sparse_tensor_to_dense(parsed_features["channels"]) + # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"]) + images = [] + print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) + images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + seqs["images"] = images + return seqs + filenames = self.filenames + print ("FILENAMES",filenames) + #TODO: + #temporal_filenames = self.temporal_filenames + shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) + if shuffle: + random.shuffle(filenames) + dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) # todo: what is buffer_size + print("files", self.filenames) + print("mode", self.mode) + dataset = dataset.filter(self.filter) + if shuffle: + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) + else: + dataset = dataset.repeat(self.num_epochs) + + num_parallel_calls = None if shuffle else 1 + dataset = dataset.apply(tf.contrib.data.map_and_batch( + parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) + #dataset = dataset.map(parser) + # num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. 
sampled subclips from the test set) + # dataset = dataset.apply(tf.contrib.data.map_and_batch( + # _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs + dataset = dataset.prefetch(batch_size) # Bing: Take the data to buffer inorder to save the waiting time for GPU + return dataset + + + + def make_batch(self, batch_size): + dataset = self.make_dataset_v2(batch_size) + iterator = dataset.make_one_shot_iterator() + return iterator.get_next() + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bytes_list_feature(values): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + +def _floats_feature(value): + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + +def save_tf_record(output_fname, sequences): + with tf.python_io.TFRecordWriter(output_fname) as writer: + for i in range(len(sequences)): + sequence = sequences[:,i,:,:,:] + num_frames = len(sequence) + height, width = sequence[0,:,:,0].shape + encoded_sequence = np.array([list(image) for image in sequence]) + features = tf.train.Features(feature={ + 'sequence_length': _int64_feature(num_frames), + 'height': _int64_feature(height), + 'width': _int64_feature(width), + 'channels': _int64_feature(1), + 'images/encoded': _floats_feature(encoded_sequence.flatten()), + }) + example = tf.train.Example(features=features) + writer.write(example.SerializeToString()) + +def read_frames_and_save_tf_records(output_dir,dat_npz, seq_length=20, sequences_per_file=128, height=64, width=64):#Bing: original 128 + """ + Read the moving_mnst data which is npz format, and save it to tfrecords files + The shape of dat_npz is [seq_length,number_samples,height,width] + moving_mnst only has one channel + + """ + os.makedirs(output_dir,exist_ok=True) + idx = 0 + num_samples = dat_npz.shape[1] + dat_npz = np.expand_dims(dat_npz, axis=4) #add one dim to represent channel, then got [seq_length,num_samples,height,width,channel] + print("data_npz_shape",dat_npz.shape) + dat_npz = dat_npz.astype(np.float32) + dat_npz /= 255.0 #normalize RGB codes by dividing it to the max RGB value + while idx < num_samples - sequences_per_file: + sequences = dat_npz[:,idx:idx+sequences_per_file,:,:,:] + output_fname = 'sequence_{}_{}.tfrecords'.format(idx,idx+sequences_per_file) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, sequences) + idx = idx + sequences_per_file + return None + + +def write_sequence_file(output_dir,seq_length,sequences_per_file): + partition_names = ["train","val","test"] + for partition_name in partition_names: + save_output_dir = os.path.join(output_dir,partition_name) + tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords")) + print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter)) + sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w') + for i in range(tfCounter*sequences_per_file): + sequence_lengths_file.write("%d\n" % seq_length) + sequence_lengths_file.close() + + +def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"): + """ + Plot the seq images + """ + + if len(np.array(imgs).shape)!=3:raise("img dims should be three: (seq_len,lat,lon)") + img_len = imgs.shape[0] + 
fig = plt.figure(figsize=(18,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + for i in range(img_len): + ax1 = plt.subplot(gs[i]) + plt.imshow(imgs[i] ,cmap = 'jet') + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg")) + print("images_saved") + plt.clf() + + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking") + parser.add_argument("output_dir", type=str) + parser.add_argument("-sequences_per_file",type=int,default=2) + args = parser.parse_args() + current_path = os.getcwd() + data = np.load(os.path.join(args.input_dir,"mnist_test_seq.npy")) + print("data in minist_test_Seq shape",data.shape) + seq_length = data.shape[0] + height = data.shape[2] + width = data.shape[3] + num_samples = data.shape[1] + max_npz = np.max(data) + min_npz = np.min(data) + print("max_npz,",max_npz) + print("min_npz",min_npz) + #Todo need to discuss how to split the data, since we have totally 10000 samples, the origin paper convLSTM used 10000 as training, 2000 as validation and 3000 for testing + dat_train = data[:,:6000,:,:] + dat_val = data[:,6000:7000,:,:] + dat_test = data[:,7000:,:] + #plot_seq_imgs(dat_test[10:,0,:,:],output_png_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist/convLSTM",idx=1,label="Ground Truth from npz") + #save train + #read_frames_and_save_tf_records(os.path.join(args.output_dir,"train"),dat_train, seq_length=20, sequences_per_file=40, height=height, width=width) + #save val + #read_frames_and_save_tf_records(os.path.join(args.output_dir,"val"),dat_val, seq_length=20, sequences_per_file=40, height=height, width=width) + #save test + #read_frames_and_save_tf_records(os.path.join(args.output_dir,"test"),dat_test, seq_length=20, sequences_per_file=40, height=height, width=width) + #write_sequence_file(output_dir=args.output_dir,seq_length=20,sequences_per_file=40) +if __name__ == '__main__': + main() + diff --git a/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py index 321f6cc7e05320cf83e1173d8004429edf07ec24..c4a095dc8fc3abdbd87c1eaf79adcd7dad99020b 100644 --- a/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py +++ b/video_prediction_savp/video_prediction/layers/BasicConvLSTMCell.py @@ -88,10 +88,14 @@ class BasicConvLSTMCell(ConvRNNCell): else: c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state) concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) - + print("concat1:",concat) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) - + print("input gate i:",i) + print("new_input j:",j) + print("forget gate:",f) + print("output gate:",o) + new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * self._activation(j)) new_h = self._activation(new_c) * tf.nn.sigmoid(o) @@ -100,6 +104,8 @@ class BasicConvLSTMCell(ConvRNNCell): new_state = LSTMStateTuple(new_c, new_h) else: new_state = tf.concat(axis = 3, values = [new_c, new_h]) + print("new h", new_h) + print("new state",new_state) return new_h, new_state @@ -135,9 +141,14 @@ def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=No matrix = 
tf.get_variable( "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype) if len(args) == 1: + print("args[0]:",args[0]) res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME') + print("res1:",res) else: + print("matrix:",matrix) + print("tf.concat(axis = 3, values = args):",tf.concat(axis = 3, values = args)) res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME') + print("res2:",res) if not bias: return res bias_term = tf.get_variable( @@ -146,3 +157,4 @@ def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=No initializer = tf.constant_initializer( bias_start, dtype = dtype)) return res + bias_term + diff --git a/video_prediction_savp/video_prediction/layers/layer_def.py b/video_prediction_savp/video_prediction/layers/layer_def.py index a59643c7a6d69141134ec01c9b147c4798bfed8e..273b5eaee3cab703841b214ccc09ef190b6dd3ae 100644 --- a/video_prediction_savp/video_prediction/layers/layer_def.py +++ b/video_prediction_savp/video_prediction/layers/layer_def.py @@ -3,8 +3,8 @@ import tensorflow as tf import numpy as np - weight_decay = 0.0005 + def _activation_summary(x): """Helper to create summaries for activations. Creates a summary that provides a histogram of activations. @@ -55,8 +55,7 @@ def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.l def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"): - print("conv_layer activation function",activate) - + print("conv_layer activation function",activate) with tf.variable_scope('{0}_conv'.format(idx)) as scope: input_channels = inputs.get_shape()[-1] @@ -74,6 +73,8 @@ def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.co conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx)) elif activate == "leaky_relu": conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx)) + elif activate == "sigmoid": + conv_rect = tf.nn.sigmoid(conv_biased, name = '{0}_conv'.format(idx)) else: raise ("activation function is not correct") return conv_rect @@ -157,4 +158,4 @@ def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): def bn_layers_wrapper(inputs, is_training): pass - \ No newline at end of file + diff --git a/video_prediction_savp/video_prediction/models/__init__.py b/video_prediction_savp/video_prediction/models/__init__.py index 6d7323f3750949b0ddb411d4a98934928537bc53..8010c4eeb2123fd94995c6474e7e1c8af6b02113 100644 --- a/video_prediction_savp/video_prediction/models/__init__.py +++ b/video_prediction_savp/video_prediction/models/__init__.py @@ -21,7 +21,6 @@ def get_model_class(model): 'vae': 'VanillaVAEVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', 'mcnet': 'McNetVideoPredictionModel', - } model_class = model_mappings.get(model, model) model_class = globals().get(model_class) diff --git a/video_prediction_savp/video_prediction/models/base_model.py b/video_prediction_savp/video_prediction/models/base_model.py index 846621d8ca1e235c39618951be86fe184a2d974d..0d3bf6e4b554c70671d4678b530688c44f999b77 100644 --- a/video_prediction_savp/video_prediction/models/base_model.py +++ b/video_prediction_savp/video_prediction/models/base_model.py @@ -3,12 +3,10 @@ import itertools import os import re from collections import OrderedDict - import numpy as np import tensorflow as tf from tensorflow.contrib.training import HParams from 
tensorflow.python.util import nest - import video_prediction as vp from video_prediction.utils import tf_utils from video_prediction.utils.tf_utils import compute_averaged_gradients, reduce_tensors, local_device_setter, \ @@ -244,7 +242,9 @@ class BaseVideoPredictionModel(object): savers.append(saver) restore_op = [saver.saver_def.restore_op_name for saver in savers] sess.run(restore_op) - + return True + else: + return False class VideoPredictionModel(BaseVideoPredictionModel): def __init__(self, @@ -487,6 +487,7 @@ class VideoPredictionModel(BaseVideoPredictionModel): # skip_vars = {" discriminator_encoder/video_sn_fc4/dense/bias"} if self.num_gpus <= 1: # cpu or 1 gpu + print("self.inputs:>20200822",self.inputs) outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs) self.outputs, self.eval_outputs = outputs_tuple self.d_losses, self.g_losses, g_losses_post = losses_tuple diff --git a/video_prediction_savp/video_prediction/models/savp_model.py b/video_prediction_savp/video_prediction/models/savp_model.py index c510d050c89908d0e06fe0f1a66e355e61c90530..039533864f34d1608c5a10d4a664d40ce73594a7 100644 --- a/video_prediction_savp/video_prediction/models/savp_model.py +++ b/video_prediction_savp/video_prediction/models/savp_model.py @@ -689,6 +689,10 @@ class SAVPCell(tf.nn.rnn_cell.RNNCell): def generator_given_z_fn(inputs, mode, hparams): # all the inputs needs to have the same length for unrolling the rnn print("inputs.items",inputs.items()) + #20200822 bing + inputs ={"images":inputs["images"]} + print("inputs 20200822:",inputs) + #20200822 inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1) for name, input in inputs.items()} cell = SAVPCell(inputs, mode, hparams) diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py index 5a8a2e1f3fffe5c66d5b93e53137300bf792317e..796486a453f9dc6807928deeb2b8962e2908a4f2 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py @@ -28,32 +28,29 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): self.sequence_length = self.hparams.sequence_length self.predict_frames = self.sequence_length - self.context_frames self.max_epochs = self.hparams.max_epochs + self.loss_fun = self.hparams.loss_fun + + def get_default_hparams_dict(self): """ The keys of this dict define valid hyperparameters for instances of this class. A class inheriting from this one should override this method if it has a different set of hyperparameters. - Returns: A dict with the following hyperparameters. - batch_size: batch size for training. lr: learning rate. if decay steps is non-zero, this is the learning rate for steps <= decay_step. - max_steps: number of training steps. - context_frames: the number of ground-truth frames to pass :qin at - start. Must be specified during instantiation. - sequence_length: the number of frames in the video sequence, - including the context frames, so this model predicts - `sequence_length - context_frames` future frames. Must be - specified during instantiation. 
- """ + max_epochs: number of training epochs, each epoch equal to sample_size/batch_size + loss_fun: string can be either "rmse" or "cross_entropy", loss function has to be set from the user + """ default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict() print ("default hparams",default_hparams) hparams = dict( batch_size=16, lr=0.001, max_epochs=3000, + loss_fun = None ) return dict(itertools.chain(default_hparams.items(), hparams.items())) @@ -70,9 +67,22 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): #self.context_frames_loss = tf.reduce_mean( # tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) # This is the loss function (RMSE): - self.total_loss = tf.reduce_mean( - tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_context_frames[:, (self.context_frames-1):-1, :, :, 0])) - + #This is loss function only for 1 channel (temperature RMSE) + if self.loss_fun == "rmse": + self.total_loss = tf.reduce_mean( + tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0])) + elif self.loss_fun == "cross_entropy": + x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1]) + x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:,:,:,:,0],[-1]) + bce = tf.keras.losses.BinaryCrossentropy() + self.total_loss = bce(x_flatten,x_hat_predict_frames_flatten) + else: + raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'") + + #This is the loss for only all the channels(temperature, geo500, pressure) + #self.total_loss = tf.reduce_mean( + # tf.square(self.x[:, self.context_frames:,:,:,:] - self.x_hat_predict_frames[:,:,:,:,:])) + self.train_op = tf.train.AdamOptimizer( learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) self.outputs = {} @@ -84,60 +94,40 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): self.saveable_variables = [self.global_step] + global_variables return None - @staticmethod def convLSTM_cell(inputs, hidden): - - conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1", activate = "leaky_relu") - - conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2", activate = "leaky_relu") - - conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3", activate = "leaky_relu") - - y_0 = conv3 + y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset + channels = inputs.get_shape()[-1] # conv lstm cell cell_shape = y_0.get_shape().as_list() + channels = cell_shape[-1] with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)): - cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [3, 3], num_features = 8) + cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size = [5, 5], num_features = 256) if hidden is None: hidden = cell.zero_state(y_0, tf.float32) - output, hidden = cell(y_0, hidden) - - output_shape = output.get_shape().as_list() - - z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) - - conv5 = ld.transpose_conv_layer(z3, 3, 2, 8, "decode_5", activate = "leaky_relu") - - - conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6", activate = "leaky_relu") - - - x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, "decode_7", activate = "sigmoid") # set activation to linear - + #we feed the learn representation into a 1 × 1 convolutional layer to 
generate the final prediction + x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") return x_hat, hidden def convLSTM_network(self): network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables # create network - x_hat_context = [] x_hat = [] - hidden = None - #This is for training - for i in range(self.sequence_length): - if i < self.context_frames: - x_1, hidden = network_template(self.x[:, i, :, :, :], hidden) - else: - x_1, hidden = network_template(x_1, hidden) - x_hat_context.append(x_1) + + # for i in range(self.sequence_length-1): + # if i < self.context_frames: + # x_1, hidden = network_template(self.x[:, i, :, :, :], hidden) + # else: + # x_1, hidden = network_template(x_1, hidden) + # x_hat_context.append(x_1) - #This is for generating video + #This is for training (optimization of convLSTM layer) hidden_g = None - for i in range(self.sequence_length): + for i in range(self.sequence_length-1): if i < self.context_frames: x_1_g, hidden_g = network_template(self.x[:, i, :, :, :], hidden_g) else: @@ -145,8 +135,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): x_hat.append(x_1_g) # pack them all together - x_hat_context = tf.stack(x_hat_context) x_hat = tf.stack(x_hat) - self.x_hat_context_frames = tf.transpose(x_hat_context, [1, 0, 2, 3, 4]) # change first dim with sec dim self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim - self.x_hat_predict_frames = self.x_hat[:,self.context_frames:,:,:,:] + self.x_hat_predict_frames = self.x_hat[:,self.context_frames-1:,:,:,:] +
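Note on the windowing in era5_dataset_v2.read_frames_and_save_tf_records: every index that still leaves room for a full window of seq_length frames starts a new sequence, and the timestamp of the window's first frame is recorded alongside it so the written TFRecords stay traceable to the source data. A minimal NumPy sketch of that step (sliding_sequences is an illustrative name, not a function from the repository):

import numpy as np

def sliding_sequences(x, t, seq_length):
    """Build overlapping windows of length `seq_length` from the frame array `x`
    (shape [n_frames, height, width, n_vars]) and record the timestamp of the
    first frame of every window from the parallel list `t`."""
    assert len(x) == len(t)
    sequences, t_starts = [], []
    for start in range(len(x) - seq_length):
        sequences.append(x[start:start + seq_length])
        t_starts.append(t[start])
    return np.array(sequences), t_starts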
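moving_mnist.py serializes each normalized sequence as one flat float feature plus its sequence_length, height, width and channels, and make_dataset_v2 parses that flat feature back and reshapes it to [sequence_length, height, width, channels]. A self-contained sketch of this round-trip for a single sequence, using the same TF1 calls as the patch (serialize_sequence and parse_sequence are illustrative helper names):

import numpy as np
import tensorflow as tf

def serialize_sequence(seq):
    """seq: float32 array of shape [T, H, W, C], values already scaled to [0, 1]."""
    t, h, w, c = seq.shape
    feature = {
        'sequence_length': tf.train.Feature(int64_list=tf.train.Int64List(value=[t])),
        'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[h])),
        'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[w])),
        'channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[c])),
        'images/encoded': tf.train.Feature(float_list=tf.train.FloatList(value=seq.flatten())),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

def parse_sequence(serialized, video_shape):
    """Inverse step used in make_dataset_v2: flat float feature -> [T, H, W, C] tensor."""
    parsed = tf.parse_single_example(serialized, {'images/encoded': tf.VarLenFeature(tf.float32)})
    flat = tf.sparse_tensor_to_dense(parsed['images/encoded'])
    return tf.reshape(flat, video_shape)

# one dummy 20-frame 64x64 single-channel sequence; evaluate `images` inside a tf.Session (TF1)
example = serialize_sequence(np.random.rand(20, 64, 64, 1).astype(np.float32))
images = parse_sequence(example, [20, 64, 64, 1])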
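The debug prints added to BasicConvLSTMCell expose the usual ConvLSTM gating: one convolution over the concatenated input and hidden state yields 4*num_features channels that are split into the input (i), candidate (j), forget (f) and output (o) gates. A condensed restatement of that arithmetic (convlstm_gates and kernel are illustrative names; the bias term added in _conv_linear is omitted):

import tensorflow as tf

def convlstm_gates(inputs, h, c, kernel, forget_bias=1.0, activation=tf.tanh):
    """kernel has shape [k_h, k_w, channels(inputs)+channels(h), 4*num_features].
    A single convolution over the concatenated input and hidden state produces
    the four gates; the cell and hidden states are then updated as in the patch."""
    concat = tf.nn.conv2d(tf.concat([inputs, h], axis=3), kernel,
                          strides=[1, 1, 1, 1], padding='SAME')
    i, j, f, o = tf.split(concat, 4, axis=3)
    new_c = c * tf.nn.sigmoid(f + forget_bias) + tf.nn.sigmoid(i) * activation(j)
    new_h = activation(new_c) * tf.nn.sigmoid(o)
    return new_h, new_c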
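The rewritten convLSTM_network unrolls the shared cell for sequence_length-1 steps, feeding ground-truth frames for the first context_frames steps and the model's own prediction afterwards, and the loss construction now switches between a mean-squared-error objective (labelled "rmse") and binary cross-entropy on the predicted frames. A standalone sketch of this pattern under those assumptions (rollout, prediction_loss and step_fn are illustrative names, not identifiers from the repository):

import tensorflow as tf

def rollout(x, step_fn, context_frames, sequence_length):
    """Feed ground-truth frames for the first `context_frames` steps, then feed
    the model's own prediction back in (closed loop); x is [batch, time, h, w, c]."""
    preds, hidden, x_t = [], None, None
    for t in range(sequence_length - 1):
        inp = x[:, t] if t < context_frames else x_t
        x_t, hidden = step_fn(inp, hidden)
        preds.append(x_t)
    return tf.transpose(tf.stack(preds), [1, 0, 2, 3, 4])  # -> [batch, time-1, h, w, c]

def prediction_loss(x, preds, context_frames, loss_fun="rmse"):
    """Loss on the predicted future frames of channel 0 only, as in the patch."""
    target = x[:, context_frames:, :, :, 0]
    future = preds[:, context_frames - 1:, :, :, 0]
    if loss_fun == "rmse":           # mean squared error (the patch labels this "rmse")
        return tf.reduce_mean(tf.square(target - future))
    if loss_fun == "cross_entropy":  # expects values scaled to [0, 1]
        bce = tf.keras.losses.BinaryCrossentropy()
        return bce(tf.reshape(target, [-1]), tf.reshape(future, [-1]))
    raise ValueError("loss_fun must be 'rmse' or 'cross_entropy'")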