diff --git a/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh new file mode 100755 index 0000000000000000000000000000000000000000..a81e9a1499ce2619c6d934d32396c7128bd6b565 --- /dev/null +++ b/video_prediction_savp/HPC_scripts/DataPreprocess_to_tf_movingmnist.sh @@ -0,0 +1,37 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +##SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --output=DataPreprocess_to_tf-out.%j +#SBATCH --error=DataPreprocess_to_tf-err.%j +#SBATCH --time=00:20:00 +#SBATCH --partition=devel +#SBATCH --mail-type=ALL +#SBATCH --mail-user=b.gong@fz-juelich.de + + +# Name of virtual environment +VIRT_ENV_NAME="vp" + +# Loading mouldes +source ../env_setup/modules_train.sh +# Activate virtual environment if needed (and possible) +if [ -z ${VIRTUAL_ENV} ]; then + if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then + echo "Activating virtual environment..." + source ../${VIRT_ENV_NAME}/bin/activate + else + echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." + exit 1 + fi +fi + +# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) + +source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist +destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist + +# run Preprocessing (step 2 where Tf-records are generated) +srun python ../video_prediction/datasets/moving_mnist.py ${source_dir} ${destination_dir}/tfrecords diff --git a/video_prediction_savp/HPC_scripts/generate_movingmnist.sh b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh new file mode 100755 index 0000000000000000000000000000000000000000..f6a36d1a366a1e1492546e8a8ac14760cc434098 --- /dev/null +++ b/video_prediction_savp/HPC_scripts/generate_movingmnist.sh @@ -0,0 +1,44 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +##SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --output=generate_era5-out.%j +#SBATCH --error=generate_era5-err.%j +#SBATCH --time=00:20:00 +#SBATCH --gres=gpu:1 +#SBATCH --partition=develgpus +#SBATCH --mail-type=ALL +#SBATCH --mail-user=b.gong@fz-juelich.de +##jutil env activate -p cjjsc42 + +# Name of virtual environment +VIRT_ENV_NAME="vp" + +# Loading mouldes +source ../env_setup/modules_train.sh +# Activate virtual environment if needed (and possible) +if [ -z ${VIRTUAL_ENV} ]; then + if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then + echo "Activating virtual environment..." + source ../${VIRT_ENV_NAME}/bin/activate + else + echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." 
+ exit 1 + fi +fi + +# declare directory-variables which will be modified appropriately during Preprocessing (invoked by mpi_split_data_multi_years.py) +source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist +checkpoint_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist +results_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist +# name of model +model=convLSTM + +# run postprocessing/generation of model results including evaluation metrics +srun python -u ../scripts/generate_movingmnist.py \ +--input_dir ${source_dir}/ --dataset_hparams sequence_length=20 --checkpoint ${checkpoint_dir}/${model} \ +--mode test --model ${model} --results_dir ${results_dir}/${model} --batch_size 2 --dataset era5 > generate_era5-out.out + +#srun python scripts/train.py --input_dir data/era5 --dataset era5 --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp diff --git a/video_prediction_savp/HPC_scripts/train_movingmnist.sh b/video_prediction_savp/HPC_scripts/train_movingmnist.sh index 1ceaebbfcfe62bca30c9dd728b61f6ed2bc22d4d..f62d333dbf01db0affbf72a3e1ef1ecd96b94ec7 100755 --- a/video_prediction_savp/HPC_scripts/train_movingmnist.sh +++ b/video_prediction_savp/HPC_scripts/train_movingmnist.sh @@ -38,7 +38,7 @@ destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/model # for choosing the model, convLSTM,savp, mcnet,vae,convLSTM_Loliver model=convLSTM -model_hparams=../hparams/era5/${model}/model_hpain_movingmnist.shrams.json +model_hparams=../hparams/era5/${model}/model_hparams.json # rund training srun python ../scripts/train_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/ --checkpoint ${destination_dir}/${model}/ diff --git a/video_prediction_savp/scripts/generate_movingmnist.py b/video_prediction_savp/scripts/generate_movingmnist.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec2af488c81dddeef6bff2deeb867c4e7b4ffed --- /dev/null +++ b/video_prediction_savp/scripts/generate_movingmnist.py @@ -0,0 +1,822 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import math +import random +import cv2 +import numpy as np +import tensorflow as tf +import pickle +from random import seed +import random +import json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import matplotlib.animation as animation +import pandas as pd +import re +from video_prediction import datasets, models +from matplotlib.colors import LinearSegmentedColormap +#from matplotlib.ticker import MaxNLocator +#from video_prediction.utils.ffmpeg_gif import save_gif +from skimage.metrics import structural_similarity as ssim +import datetime +# Scarlet 2020/05/28: access to statistical values in json file +from os import path +import sys +sys.path.append(path.abspath('../video_prediction/datasets/')) +from era5_dataset_v2 import Norm_data +from os.path import dirname +from netCDF4 import Dataset,date2num +from metadata import MetaData as MetaData + +def set_seed(seed): + if seed is not None: + tf.set_random_seed(seed) + np.random.seed(seed) + random.seed(seed) + +def get_coordinates(metadata_fname): + """ + Retrieves 
the latitudes and longitudes read from the metadata json file. + """ + md = MetaData(json_file=metadata_fname) + md.get_metadata_from_file(metadata_fname) + + try: + print("lat:",md.lat) + print("lon:",md.lon) + return md.lat, md.lon + except: + raise ValueError("Error when handling: '"+metadata_fname+"'") + + +def load_checkpoints_and_create_output_dirs(checkpoint,dataset,model): + if checkpoint: + checkpoint_dir = os.path.normpath(checkpoint) + if not os.path.isdir(checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % checkpoint) + options = json.loads(f.read()) + dataset = dataset or options['dataset'] + model = model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: + dataset_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("dataset_hparams.json was not loaded because it does not exist") + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict = json.loads(f.read()) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + else: + if not dataset: + raise ValueError('dataset is required when checkpoint is not specified') + if not model: + raise ValueError('model is required when checkpoint is not specified') + + return options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict + + + +def setup_dataset(dataset,input_dir,mode,seed,num_epochs,dataset_hparams,dataset_hparams_dict): + VideoDataset = datasets.get_dataset_class(dataset) + dataset = VideoDataset( + input_dir, + mode = mode, + num_epochs = num_epochs, + seed = seed, + hparams_dict = dataset_hparams_dict, + hparams = dataset_hparams) + return dataset + + +def setup_dirs(input_dir,results_png_dir): + input_dir = args.input_dir + temporal_dir = os.path.split(input_dir)[0] + "/hickle/splits/" + print ("temporal_dir:",temporal_dir) + + +def update_hparams_dict(model_hparams_dict,dataset): + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + return hparams_dict + + +def psnr(img1, img2): + mse = np.mean((img1 - img2) ** 2) + if mse == 0: return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) + + +def setup_num_samples_per_epoch(num_samples, dataset): + if num_samples: + if num_samples > dataset.num_examples_per_epoch(): + raise ValueError('num_samples cannot be larger than the dataset') + num_examples_per_epoch = num_samples + else: + num_examples_per_epoch = dataset.num_examples_per_epoch() + #if num_examples_per_epoch % args.batch_size != 0: + # raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) + return num_examples_per_epoch + + +def initia_save_data(): + sample_ind = 0 + gen_images_all = [] + #Bing:20200410 + persistent_images_all = [] + input_images_all = [] + return sample_ind, gen_images_all,persistent_images_all, input_images_all + + +def write_params_to_results_dir(args,output_dir,dataset,model): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys = True, indent 
= 4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(dataset.hparams.values(), sort_keys = True, indent = 4)) + with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4)) + return None + +def get_one_seq_and_time(input_images,i): + assert (len(np.array(input_images).shape)==5) + input_images_ = input_images[i,:,:,:,:] + return input_images_ + + +def denorm_images_all_channels(input_images_): + input_images_all_channles_denorm = [] + input_images_ = np.array(input_images_) + input_images_denorm = input_images_ * 255.0 + #print("input_images_denorm shape",input_images_denorm.shape) + return input_images_denorm + +def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"): + """ + Plot the seq images + """ + + if len(np.array(imgs).shape)!=3:raise("img dims should be three: (seq_len,lat,lon)") + img_len = imgs.shape[0] + fig = plt.figure(figsize=(18,6)) + gs = gridspec.GridSpec(1, 10) + gs.update(wspace = 0., hspace = 0.) + for i in range(img_len): + ax1 = plt.subplot(gs[i]) + plt.imshow(imgs[i] ,cmap = 'jet') + plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg")) + print("images_saved") + plt.clf() + + + +def get_persistence(ts): + pass + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type = str, required = True, + help = "either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--results_dir", type = str, default = 'results', + help = "ignored if output_gif_dir is specified") + parser.add_argument("--checkpoint", + help = "directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + parser.add_argument("--mode", type = str, choices = ['train','val', 'test'], default = 'val', + help = 'mode for dataset, val or test.') + parser.add_argument("--dataset", type = str, help = "dataset class name") + parser.add_argument("--dataset_hparams", type = str, + help = "a string of comma separated list of dataset hyperparameters") + parser.add_argument("--model", type = str, help = "model class name") + parser.add_argument("--model_hparams", type = str, + help = "a string of comma separated list of model hyperparameters") + parser.add_argument("--batch_size", type = int, default = 8, help = "number of samples in batch") + parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)") + parser.add_argument("--num_epochs", type = int, default = 1) + parser.add_argument("--num_stochastic_samples", type = int, default = 1) + parser.add_argument("--gif_length", type = int, help = "default is sequence_length") + parser.add_argument("--fps", type = int, default = 4) + parser.add_argument("--gpu_mem_frac", type = float, default = 0.95, help = "fraction of gpu memory to use") + parser.add_argument("--seed", type = int, default = 7) + args = parser.parse_args() + set_seed(args.seed) + + dataset_hparams_dict = {} + model_hparams_dict = {} + + options,dataset,model, checkpoint_dir,dataset_hparams_dict,model_hparams_dict = load_checkpoints_and_create_output_dirs(args.checkpoint,args.dataset,args.model) + print("Step 1 finished") + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + #setup dataset and model object + input_dir_tf = os.path.join(args.input_dir, "tfrecords") # where tensorflow records are stored + dataset = setup_dataset(dataset,input_dir_tf,args.mode,args.seed,args.num_epochs,args.dataset_hparams,dataset_hparams_dict) + + print("Step 2 finished") + VideoPredictionModel = models.get_model_class(model) + + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': dataset.hparams.context_frames, + 'sequence_length': dataset.hparams.sequence_length, + 'repeat': dataset.hparams.time_shift, + }) + + model = VideoPredictionModel( + mode = args.mode, + hparams_dict = hparams_dict, + hparams = args.model_hparams) + + sequence_length = model.hparams.sequence_length + context_frames = model.hparams.context_frames + future_length = sequence_length - context_frames #context_Frames is the number of input frames + + num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset) + + inputs = dataset.make_batch(args.batch_size) + print("inputs",inputs) + input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} + print("input_phs",input_phs) + + + # Build graph + with tf.variable_scope(''): + model.build_graph(input_phs) + + #Write the update hparameters into results_dir + write_params_to_results_dir(args=args,output_dir=args.results_dir,dataset=dataset,model=model) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_mem_frac) + config = tf.ConfigProto(gpu_options = gpu_options, allow_soft_placement = True) + sess = tf.Session(config = config) + sess.graph.as_default() + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + model.restore(sess, args.checkpoint) + + #model.restore(sess, 
args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data + sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data() + + is_first=True + #loop for in samples + while sample_ind < 5: + gen_images_stochastic = [] + if args.num_samples and sample_ind >= args.num_samples: + break + try: + input_results = sess.run(inputs) + input_images = input_results["images"] + #get the intial times + t_starts = input_results["T_start"] + except tf.errors.OutOfRangeError: + break + + #Get prediction values + feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} + gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel] + print("gen_images 20200822:",np.array(gen_images).shape) + #Loop in batch size + for i in range(args.batch_size): + + #get one seq and the corresponding start time point + input_images_ = get_one_seq_and_time(input_images,i) + + #Renormalized data for inputs + input_images_denorm = denorm_images_all_channels(input_images_) + print("input_images_denorm",input_images_denorm[0][0]) + + #Renormalized data for inputs + gen_images_ = gen_images[i] + gen_images_denorm = denorm_images_all_channels(gen_images_) + print("gene_images_denorm:",gen_images_denorm[0][0]) + + #Generate images inputs + plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],idx = sample_ind + i, label="Ground Truth",output_png_dir=args.results_dir) + + #Generate forecast images + plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],idx = sample_ind + i,label="Forecast by Model " + args.model,output_png_dir=args.results_dir) + + #TODO: Scaret plot persistence image + #implment get_persistence() function + + #in case of generate the images for all the input, we just generate the first 5 sampe_ind examples for visuliation + + sample_ind += args.batch_size + + + #for input_image in input_images_: + +# for stochastic_sample_ind in range(args.num_stochastic_samples): +# input_images_all.extend(input_images) +# with open(os.path.join(args.output_png_dir, "input_images_all.pkl"), "wb") as input_files: +# pickle.dump(list(input_images_all), input_files) + + +# gen_images_stochastic.append(gen_images) +# #print("Stochastic_sample,", stochastic_sample_ind) +# for i in range(args.batch_size): +# #bing:20200417 +# t_stampe = test_temporal_pkl[sample_ind+i] +# print("timestamp:",type(t_stampe)) +# persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1) +# print ("persistent ts",persistent_ts) +# persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts)) +# persistent_X = X_test[persistent_idx:persistent_idx+context_frames + future_length] +# print("persistent index in test set:", persistent_idx) +# print("persistent_X.shape",persistent_X.shape) +# persistent_images_all.append(persistent_X) + +# cmap_name = 'my_list' +# if sample_ind < 100: +# #name = '_Stochastic_id_' + str(stochastic_sample_ind) + 'Batch_id_' + str( +# # sample_ind) + " + Sample_" + str(i) +# name = '_Stochastic_id_' + str(stochastic_sample_ind) + "_Time_"+ t_stampe[0].strftime("%Y%m%d-%H%M%S") +# print ("name",name) +# gen_images_ = np.array(list(input_images[i,:context_frames]) + list(gen_images[i,-future_length:, :])) +# #gen_images_ = gen_images[i, :] +# input_images_ = input_images[i, :] +# #Bing:20200417 +# #persistent_images = ? 
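+#                 # NOTE: the commented-out block below is legacy ERA5 evaluation code; it denormalizes the T2 channel,
+#                 # computes the per-frame MSE of the model and of the persistence forecast, and plots the
+#                 # ground-truth, predicted and persistent sequences for each sample.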
+# #+++Scarlet:20200528 +# #print('Scarlet1') +# input_gen_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(gen_images_[:, :, :, 0],'T2',norm) +# persistent_diff = norm_cls.denorm_var(input_images_[:, :, :,0], 'T2', norm) - norm_cls.denorm_var(persistent_X[:, :, :, 0], 'T2',norm) +# #---Scarlet:20200528 +# gen_mse_avg_ = [np.mean(input_gen_diff[frame, :, :] ** 2) for frame in +# range(sequence_length)] # return the list with 10 (sequence) mse +# persistent_mse_avg_ = [np.mean(persistent_diff[frame, :, :] ** 2) for frame in +# range(sequence_length)] # return the list with 10 (sequence) mse + +# fig = plt.figure(figsize=(18,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) +# ts = list(range(10,20)) #[10,11,12,..] +# xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] +# ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] + +# for t in ts: + +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #+++Scarlet:20200528 +# #print('Scarlet2') +# input_image = norm_cls.denorm_var(input_images_[t, :, :, 0], 'T2', norm) +# #---Scarlet:20200528 +# plt.imshow(input_image, cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) +# if t == 0: +# plt.setp([ax1], xticks = list(np.linspace(0, 64, 3)), xticklabels = xlables, yticks = list(np.linspace(0, 64, 3)), yticklabels = ylabels) +# plt.ylabel("Ground Truth", fontsize=10) +# plt.savefig(os.path.join(args.output_png_dir, "Ground_Truth_Sample_" + str(name) + ".jpg")) +# plt.clf() + +# fig = plt.figure(figsize=(12,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) + +# for t in ts: +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #+++Scarlet:20200528 +# #print('Scarlet3') +# gen_image = norm_cls.denorm_var(gen_images_[t, :, :, 0], 'T2', norm) +# #---Scarlet:20200528 +# plt.imshow(gen_image, cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + +# plt.savefig(os.path.join(args.output_png_dir, "Predicted_Sample_" + str(name) + ".jpg")) +# plt.clf() + + +# fig = plt.figure(figsize=(12,6)) +# gs = gridspec.GridSpec(1, 10) +# gs.update(wspace = 0., hspace = 0.) 
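+#                 # Plot the predicted frames (denormalized T2) of this sequence and save them as one "Predicted_Sample_" JPG.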
+# for t in ts: +# #if t==0 : ax1=plt.subplot(gs[t]) +# ax1 = plt.subplot(gs[ts.index(t)]) +# #persistent_image = persistent_X[t, :, :, 0] * (321.46630859375 - 235.2141571044922) + 235.2141571044922 +# plt.imshow(persistent_X[t, :, :, 0], cmap = 'jet', vmin=270, vmax=300) +# ax1.title.set_text("t = " + str(t+1-10)) +# plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) + +# plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg")) +# plt.clf() + + +# with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: +# pickle.dump(list(persistent_images_all), input_files) +# print ("Save persistent all") +# if is_first: +# gen_images_all = gen_images_stochastic +# is_first = False +# else: +# gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + +# if args.num_stochastic_samples == 1: +# with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: +# pickle.dump(list(gen_images_all[0]), gen_files) +# print ("Save generate all") +# else: +# with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: +# pickle.dump(list(gen_images_stochastic), gen_files) +# with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: +# pickle.dump(list(gen_images_all), gen_files) + +# sample_ind += args.batch_size + + +# with open(os.path.join(args.output_png_dir, "input_images_all.pkl"),"rb") as input_files: +# input_images_all = pickle.load(input_files) + +# with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"),"rb") as gen_files: +# gen_images_all = pickle.load(gen_files) + +# with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"),"rb") as gen_files: +# persistent_images_all = pickle.load(gen_files) + +# #+++Scarlet:20200528 +# #print('Scarlet4') +# input_images_all = np.array(input_images_all) +# input_images_all = norm_cls.denorm_var(input_images_all, 'T2', norm) +# #---Scarlet:20200528 +# persistent_images_all = np.array(persistent_images_all) +# if len(np.array(gen_images_all).shape) == 6: +# for i in range(len(gen_images_all)): +# #+++Scarlet:20200528 +# #print('Scarlet5') +# gen_images_all_stochastic = np.array(gen_images_all)[i,:,:,:,:,:] +# gen_images_all_stochastic = norm_cls.denorm_var(gen_images_all_stochastic, 'T2', norm) +# #gen_images_all_stochastic = np.array(gen_images_all_stochastic) * (321.46630859375 - 235.2141571044922) + 235.2141571044922 +# #---Scarlet:20200528 +# mse_all = [] +# psnr_all = [] +# ssim_all = [] +# f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction_stochastic_{}.txt'.format(i)), 'w') +# for i in range(future_length): +# mse_model = np.mean((input_images_all[:, i + 10, :, :, 0] - gen_images_all_stochastic[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first +# psnr_model = psnr(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0]) +# ssim_model = ssim(input_images_all[:, i + 10, :, :, 0], gen_images_all_stochastic[:, i + 9, :, :, 0], +# data_range = max(gen_images_all_stochastic[:, i + 9, :, :, 0].flatten()) - min( +# input_images_all[:, i + 10, :, :, 0].flatten())) +# mse_all.extend([mse_model]) +# psnr_all.extend([psnr_model]) +# ssim_all.extend([ssim_model]) +# results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} +# f.write("##########Predicted Frame {}\n".format(str(i + 1))) +# f.write("Model MSE: %f\n" % mse_model) +# 
# f.write("Previous Frame MSE: %f\n" % mse_prev) +# f.write("Model PSNR: %f\n" % psnr_model) +# f.write("Model SSIM: %f\n" % ssim_model) + + +# pickle.dump(results, open(os.path.join(args.output_png_dir, "results_stochastic_{}.pkl".format(i)), "wb")) +# # f.write("Previous frame PSNR: %f\n" % psnr_prev) +# f.write("Shape of X_test: " + str(input_images_all.shape)) +# f.write("") +# f.write("Shape of X_hat: " + str(gen_images_all_stochastic.shape)) + +# else: +# #+++Scarlet:20200528 +# #print('Scarlet6') +# gen_images_all = np.array(gen_images_all) +# gen_images_all = norm_cls.denorm_var(gen_images_all, 'T2', norm) +# #---Scarlet:20200528 + +# # mse_model = np.mean((input_images_all[:, 1:,:,:,0] - gen_images_all[:, 1:,:,:,0])**2) # look at all timesteps except the first +# # mse_model_last = np.mean((input_images_all[:, future_length-1,:,:,0] - gen_images_all[:, future_length-1,:,:,0])**2) +# # mse_prev = np.mean((input_images_all[:, :-1,:,:,0] - gen_images_all[:, 1:,:,:,0])**2 ) +# mse_all = [] +# psnr_all = [] +# ssim_all = [] +# persistent_mse_all = [] +# persistent_psnr_all = [] +# persistent_ssim_all = [] +# f = open(os.path.join(args.output_png_dir, 'prediction_scores_4prediction.txt'), 'w') +# for i in range(future_length): +# mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - gen_images_all[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first +# persistent_mse_model = np.mean((input_images_all[:1268, i + 10, :, :, 0] - persistent_images_all[:, i + 9, :, :, +# 0]) ** 2) # look at all timesteps except the first + +# psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0]) +# ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], gen_images_all[:, i + 9, :, :, 0], +# data_range = max(gen_images_all[:, i + 9, :, :, 0].flatten()) - min( +# input_images_all[:, i + 10, :, :, 0].flatten())) +# persistent_psnr_model = psnr(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0]) +# persistent_ssim_model = ssim(input_images_all[:1268, i + 10, :, :, 0], persistent_images_all[:, i + 9, :, :, 0], +# data_range = max(gen_images_all[:1268, i + 9, :, :, 0].flatten()) - min(input_images_all[:1268, i + 10, :, :, 0].flatten())) +# mse_all.extend([mse_model]) +# psnr_all.extend([psnr_model]) +# ssim_all.extend([ssim_model]) +# persistent_mse_all.extend([persistent_mse_model]) +# persistent_psnr_all.extend([persistent_psnr_model]) +# persistent_ssim_all.extend([persistent_ssim_model]) +# results = {"mse": mse_all, "psnr": psnr_all, "ssim": ssim_all} + +# persistent_results = {"mse": persistent_mse_all, "psnr": persistent_psnr_all, "ssim": persistent_ssim_all} +# f.write("##########Predicted Frame {}\n".format(str(i + 1))) +# f.write("Model MSE: %f\n" % mse_model) +# # f.write("Previous Frame MSE: %f\n" % mse_prev) +# f.write("Model PSNR: %f\n" % psnr_model) +# f.write("Model SSIM: %f\n" % ssim_model) + +# pickle.dump(results, open(os.path.join(args.output_png_dir, "results.pkl"), "wb")) +# pickle.dump(persistent_results, open(os.path.join(args.output_png_dir, "persistent_results.pkl"), "wb")) +# # f.write("Previous frame PSNR: %f\n" % psnr_prev) +# f.write("Shape of X_test: " + str(input_images_all.shape)) +# f.write("") +# f.write("Shape of X_hat: " + str(gen_images_all.shape) + +if __name__ == '__main__': + main() + + #psnr_model = psnr(input_images_all[:, :10, :, :, 0], gen_images_all[:, :10, :, :, 0]) + #psnr_model_last = psnr(input_images_all[:, 10, :, :, 0], gen_images_all[:,10, :, :, 0]) + 
#psnr_prev = psnr(input_images_all[:, :, :, :, 0], input_images_all[:, 1:10, :, :, 0]) + + # ims = [] + # fig = plt.figure() + # for frame in range(20): + # input_gen_diff = np.mean((np.array(gen_images_all) - np.array(input_images_all))**2, axis=0)[frame, :,:,0] # Get the first prediction frame (batch,height, width, channel) + # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(input_gen_diff, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame +1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval=1000, blit = True, repeat_delay=2000) + # ani.save(os.path.join(args.output_png_dir, "Mean_Frames.mp4")) + # plt.close("all") + + # ims = [] + # fig = plt.figure() + # for frame in range(19): + # pix_std= np.std((np.array(gen_images_all) - np.array(input_images_all))**2, axis = 0)[frame, :,:, 0] # Get the first prediction frame (batch,height, width, channel) + # #pix_mean = np.mean(input_gen_diff, axis = 0) + # #pix_std = np.std(input_gen_diff, axis=0) + # im = plt.imshow(pix_std, interpolation = 'none',cmap='PuBu') + # if frame == 0: + # fig.colorbar(im) + # ttl = plt.text(1.5, 2, "Frame_" + str(frame+1)) + # ims.append([im, ttl]) + # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # ani.save(os.path.join(args.output_png_dir, "Std_Frames.mp4")) + + # seed(1) + # s = random.sample(range(len(gen_images_all)), 100) + # print("******KDP******") + # #kernel density plot for checking the model collapse + # fig = plt.figure() + # kdp = sns.kdeplot(gen_images_all[s].flatten(), shade=True, color="r", label = "Generate Images") + # kdp = sns.kdeplot(input_images_all[s].flatten(), shade=True, color="b", label = "Ground True") + # kdp.set(xlabel = 'Temperature (K)', ylabel = 'Probability') + # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) + # plt.clf() + + #line plot for evaluating the prediction and groud-truth + # for i in [0,3,6,9,12,15,18]: + # fig = plt.figure() + # plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) + # #plt.scatter(gen_images_all[:,0,:,:].flatten(),input_images_all[:,0,:,:].flatten(),s=0.3) + # plt.xlabel("Prediction") + # plt.ylabel("Real values") + # plt.title("Frame_{}".format(i+1)) + # plt.plot([250,300], [250,300],color="black") + # plt.savefig(os.path.join(args.output_png_dir,"pred_real_frame_{}.png".format(str(i)))) + # plt.clf() + # + # mse_model_by_frames = np.mean((input_images_all[:, :, :, :, 0][s] - gen_images_all[:, :, :, :, 0][s]) ** 2,axis=(2,3)) #return (batch, sequence) + # x = [str(i+1) for i in list(range(19))] + # fig,axis = plt.subplots() + # mean_f = np.mean(mse_model_by_frames, axis = 0) + # median = np.median(mse_model_by_frames, axis=0) + # q_low = np.quantile(mse_model_by_frames, q=0.25, axis=0) + # q_high = np.quantile(mse_model_by_frames, q=0.75, axis=0) + # d_low = np.quantile(mse_model_by_frames,q=0.1, axis=0) + # d_high = np.quantile(mse_model_by_frames, q=0.9, axis=0) + # plt.fill_between(x, d_high, d_low, color="ghostwhite",label="interdecile range") + # plt.fill_between(x,q_high, q_low , color = "lightgray", label="interquartile range") + # plt.plot(x, median, color="grey", linewidth=0.6, label="Median") + # plt.plot(x, mean_f, color="peachpuff",linewidth=1.5, label="Mean") + # plt.title(f'MSE percentile') + # plt.xlabel("Frames") + # plt.legend(loc=2, fontsize=8) 
+ # plt.savefig(os.path.join(args.output_png_dir,"mse_percentiles.png")) + + +## +## +## # fig = plt.figure() +## # gs = gridspec.GridSpec(4,6) +## # gs.update(wspace = 0.7,hspace=0.8) +## # ax1 = plt.subplot(gs[0:2,0:3]) +## # ax2 = plt.subplot(gs[0:2,3:],sharey=ax1) +## # ax3 = plt.subplot(gs[2:4,0:3]) +## # ax4 = plt.subplot(gs[2:4,3:]) +## # xlables = [round(i,2) for i in list(np.linspace(np.min(lon),np.max(lon),5))] +## # ylabels = [round(i,2) for i in list(np.linspace(np.max(lat),np.min(lat),5))] +## # plt.setp([ax1,ax2,ax3],xticks=list(np.linspace(0,64,5)), xticklabels=xlables ,yticks=list(np.linspace(0,64,5)),yticklabels=ylabels) +## # ax1.title.set_text("(a) Ground Truth") +## # ax2.title.set_text("(b) SAVP") +## # ax3.title.set_text("(c) Diff.") +## # ax4.title.set_text("(d) MSE") +## # +## # ax1.xaxis.set_tick_params(labelsize=7) +## # ax1.yaxis.set_tick_params(labelsize = 7) +## # ax2.xaxis.set_tick_params(labelsize=7) +## # ax2.yaxis.set_tick_params(labelsize = 7) +## # ax3.xaxis.set_tick_params(labelsize=7) +## # ax3.yaxis.set_tick_params(labelsize = 7) +## # +## # init_images = np.zeros((input_images_.shape[1], input_images_.shape[2])) +## # print("inti images shape", init_images.shape) +## # xdata, ydata = [], [] +## # #plot1 = ax1.imshow(init_images, cmap='jet', vmin =0, vmax = 1) +## # #plot2 = ax2.imshow(init_images, cmap='jet', vmin =0, vmax = 1) +## # plot1 = ax1.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) +## # plot2 = ax2.imshow(init_images, cmap='jet', vmin = 270, vmax = 300) +## # #x = np.linspace(0, 64, 64) +## # #y = np.linspace(0, 64, 64) +## # #plot1 = ax1.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) +## # #plot2 = ax2.contourf(x,y,init_images, cmap='jet', vmin = np.min(input_images), vmax = np.max(input_images)) +## # fig.colorbar(plot1, ax=ax1).ax.tick_params(labelsize=7) +## # fig.colorbar(plot2, ax=ax2).ax.tick_params(labelsize=7) +## # +## # cm = LinearSegmentedColormap.from_list( +## # cmap_name, "bwr", N = 5) +## # +## # plot3 = ax3.imshow(init_images, vmin=-20, vmax=20, cmap=cm)#cmap = 'PuBu_r', +## # #plot3 = ax3.imshow(init_images, vmin = -1, vmax = 1, cmap = cm) # cmap = 'PuBu_r', +## # plot4, = ax4.plot([], [], color = "r") +## # ax4.set_xlim(0, future_length-1) +## # ax4.set_ylim(0, 20) +## # #ax4.set_ylim(0, 0.5) +## # ax4.set_xlabel("Frames", fontsize=10) +## # #ax4.set_ylabel("MSE", fontsize=10) +## # ax4.xaxis.set_tick_params(labelsize=7) +## # ax4.yaxis.set_tick_params(labelsize=7) +## # +## # +## # plots = [plot1, plot2, plot3, plot4] +## # +## # #fig.colorbar(plots[1], ax = [ax1, ax2]) +## # +## # fig.colorbar(plots[2], ax=ax3).ax.tick_params(labelsize=7) +## # #fig.colorbar(plot1[0], ax=ax1).ax.tick_params(labelsize=7) +## # #fig.colorbar(plot2[1], ax=ax2).ax.tick_params(labelsize=7) +## # +## # def animation_sample(t): +## # input_image = input_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 +## # gen_image = gen_images_[t, :, :, 0]* (321.46630859375-235.2141571044922) + 235.2141571044922 +## # diff_image = input_gen_diff[t,:,:] +## # # p = sns.lineplot(x=x,y=data,color="b") +## # # p.tick_params(labelsize=17) +## # # plt.setp(p.lines, linewidth=6) +## # plots[0].set_data(input_image) +## # plots[1].set_data(gen_image) +## # #plots[0] = ax1.contourf(x, y, input_image, cmap = 'jet', vmin = np.min(input_images),vmax = np.max(input_images)) +## # #plots[1] = ax2.contourf(x, y, gen_image, cmap = 'jet', vmin = np.min(input_images),vmax = 
np.max(input_images)) +## # plots[2].set_data(diff_image) +## # +## # if t >= future_length: +## # #data = gen_mse_avg_[:t + 1] +## # # x = list(range(len(gen_mse_avg_)))[:t+1] +## # xdata.append(t-future_length) +## # print("xdata", xdata) +## # ydata.append(gen_mse_avg_[t]) +## # print("ydata", ydata) +## # plots[3].set_data(xdata, ydata) +## # fig.suptitle("Predicted Frame " + str(t-future_length)) +## # else: +## # #plots[3].set_data(xdata, ydata) +## # fig.suptitle("Context Frame " + str(t)) +## # return plots +## # +## # ani = animation.FuncAnimation(fig, animation_sample, frames=len(gen_mse_avg_), interval = 1000, +## # repeat_delay=2000) +## # ani.save(os.path.join(args.output_png_dir, "Sample_" + str(name) + ".mp4")) +## +#### else: +#### pass +## + + + + + + # # for i, gen_mse_avg_ in enumerate(gen_mse_avg): + # # ims = [] + # # fig = plt.figure() + # # plt.xlim(0,len(gen_mse_avg_)) + # # plt.ylim(np.min(gen_mse_avg),np.max(gen_mse_avg)) + # # plt.xlabel("Frames") + # # plt.ylabel("MSE_AVG") + # # #X = list(range(len(gen_mse_avg_))) + # # #for t, gen_mse_avg_ in enumerate(gen_mse_avg): + # # def animate_metric(j): + # # data = gen_mse_avg_[:(j+1)] + # # x = list(range(len(gen_mse_avg_)))[:(j+1)] + # # p = sns.lineplot(x=x,y=data,color="b") + # # p.tick_params(labelsize=17) + # # plt.setp(p.lines, linewidth=6) + # # ani = animation.FuncAnimation(fig, animate_metric, frames=len(gen_mse_avg_), interval = 1000, repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir, "MSE_AVG" + str(i) + ".gif")) + # # + # # + # # for i, input_images_ in enumerate(input_images): + # # #context_images_ = (input_results['images'][i]) + # # #gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # ims = [] + # # fig = plt.figure() + # # for t, input_image in enumerate(input_images_): + # # im = plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2,"Frame_" + str(t)) + # # ims.append([im,ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval= 1000, blit=True,repeat_delay=2000) + # # ani.save(os.path.join(args.output_png_dir,"groud_true_images_" + str(i) + ".gif")) + # # #plt.show() + # # + # # for i,gen_images_ in enumerate(gen_images): + # # ims = [] + # # fig = plt.figure() + # # for t, gen_image in enumerate(gen_images_): + # # im = plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # ttl = plt.text(1.5, 2, "Frame_" + str(t)) + # # ims.append([im, ttl]) + # # ani = animation.ArtistAnimation(fig, ims, interval = 1000, blit = True, repeat_delay = 2000) + # # ani.save(os.path.join(args.output_png_dir, "prediction_images_" + str(i) + ".gif")) + # + # + # # for i, gen_images_ in enumerate(gen_images): + # # #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) + # # #gen_images_ = (gen_images_ * 255.0).astype(np.uint8) + # # #bing + # # context_images_ = (input_results['images'][i]) + # # gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + # # context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + # # plt.figure(figsize = (10,2)) + # # gs = gridspec.GridSpec(2,10) + # # gs.update(wspace=0.,hspace=0.) 
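+    # # The loop below draws the actual frames in the top row and the predicted frames in the bottom row
+    # # of the 2x10 grid before saving the figure as a PNG.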
+ # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2,len(str(len(gen_images_) - 1))) + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # plt.subplot(gs[t]) + # # plt.imshow(input_images[i, t, :, :, 0], interpolation = 'none') # the last index sets the channel. 0 = t2 + # # # plt.pcolormesh(X_test[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Actual', fontsize = 10) + # # + # # plt.subplot(gs[t + 10]) + # # plt.imshow(gen_images[i, t, :, :, 0], interpolation = 'none') + # # # plt.pcolormesh(X_hat[i,t,::-1,:,0], shading='bottom', cmap=plt.cm.jet) + # # plt.tick_params(axis = 'both', which = 'both', bottom = False, top = False, left = False, + # # right = False, labelbottom = False, labelleft = False) + # # if t == 0: plt.ylabel('Predicted', fontsize = 10) + # # plt.savefig(os.path.join(args.output_png_dir, gen_image_fname) + 'plot_' + str(i) + '.png') + # # plt.clf() + # + # # if args.gif_length: + # # context_and_gen_images = context_and_gen_images[:args.gif_length] + # # save_gif(os.path.join(args.output_gif_dir, gen_images_fname), + # # context_and_gen_images, fps=args.fps) + # # + # # gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) + # # for t, gen_image in enumerate(gen_images_): + # # gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) + # # if gen_image.shape[-1] == 1: + # # gen_image = np.tile(gen_image, (1, 1, 3)) + # # else: + # # gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) + # # cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py index 0e250b47df28d115c8cdfc77fc708eab5e094ce6..1fcbd1cf97442f1ea440039a0bb6769473b957f3 100644 --- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py +++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py @@ -202,11 +202,6 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t ts_len = len(ts) ts_input = ts[:context_frames] ts_forecast = ts[context_frames:] - #print("context_frame:",context_frames) - #print("future_frame",future_length) - #print("length of ts input:",len(ts_input)) - - print("input_images_ shape in netcdf,",input_images_.shape) gen_images_ = np.array(gen_images_) output_file = os.path.join(output_dir,fl_name) @@ -289,18 +284,19 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t #Temperature: t2 = nc_file.createVariable("/forecast/{}/T2".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) t2.units = 'K' - t2[:,:,:] = gen_images_[context_frames:,:,:,0] + print ("gen_images_ 20200822:",np.array(gen_images_).shape) + t2[:,:,:] = gen_images_[context_frames-1:,:,:,0] print("NetCDF created") #mean sea level pressure msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) msl.units = 'Pa' - msl[:,:,:] = gen_images_[context_frames:,:,:,1] + msl[:,:,:] = gen_images_[context_frames-1:,:,:,1] #Geopotential at 500 gph500 = 
nc_file.createVariable("/forecast/{}/GPH500".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True) gph500.units = 'm' - gph500[:,:,:] = gen_images_[context_frames:,:,:,2] + gph500[:,:,:] = gen_images_[context_frames-1:,:,:,2] print("{} created".format(output_file)) @@ -451,7 +447,8 @@ def main(): #Get prediction values feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)#return [batchsize,seq_len,lat,lon,channel] - + assert gen_images.shape[1] == sequence_length-1 #The generate images seq_len should be sequence_len -1, since the last one is not used for comparing with groud truth + print("gen_images 20200822:",np.array(gen_images).shape) #Loop in batch size for i in range(args.batch_size): @@ -459,26 +456,26 @@ def main(): input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i) #generate time stamps for sequences ts = generate_seq_timestamps(t_start,len_seq=sequence_length) - + #Renormalized data for inputs stat_fl = os.path.join(args.input_dir,"pickle/statistics.json") input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"]) - print("input_images_denorm",input_images_denorm[0][0]) + print("input_images_denorm shape",np.array(input_images_denorm).shape) #Renormalized data for inputs gen_images_ = gen_images[i] gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"]) - print("gene_images_denorm:",gen_images_denorm[0][0]) + print("gene_images_denorm shape",np.array(gen_images_denorm).shape) #Save input to netCDF file init_date_str = ts[0].strftime("%Y%m%d%H") save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str)) #Generate images inputs - plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir) + plot_seq_imgs(imgs=input_images_denorm[context_frames+1:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Ground Truth",output_png_dir=args.results_dir) #Generate forecast images - plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) + plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames+1:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir) #TODO: Scaret plot persistence image #implment get_persistence() function diff --git a/video_prediction_savp/scripts/train_moving_mnist.py b/video_prediction_savp/scripts/train_moving_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb9a6be0d23449dbdf30cc316815e2a33b29de1 --- /dev/null +++ b/video_prediction_savp/scripts/train_moving_mnist.py @@ -0,0 +1,357 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import errno +import json +import os +import random +import time +import numpy as np +import tensorflow as tf +from video_prediction import datasets, models +import matplotlib.pyplot as plt +from json import JSONEncoder +import pickle as pkl + + +class NumpyArrayEncoder(JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return JSONEncoder.default(self, obj) + + +def 
add_tag_suffix(summary, tag_suffix): + summary_proto = tf.Summary() + summary_proto.ParseFromString(summary) + summary = summary_proto + + for value in summary.value: + tag_split = value.tag.split('/') + value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:]) + return summary.SerializeToString() + +def generate_output_dir(output_dir, model,model_hparams,logs_dir,output_dir_postfix): + if output_dir is None: + list_depth = 0 + model_fname = '' + for t in ('model=%s,%s' % (model, model_hparams)): + if t == '[': + list_depth += 1 + if t == ']': + list_depth -= 1 + if list_depth and t == ',': + t = '..' + if t in '=,': + t = '.' + if t in '[]': + t = '' + model_fname += t + output_dir = os.path.join(logs_dir, model_fname) + output_dir_postfix + return output_dir + + +def get_model_hparams_dict(model_hparams_dict): + """ + Get model_hparams_dict from json file + """ + model_hparams_dict_load = {} + if model_hparams_dict: + with open(model_hparams_dict) as f: + model_hparams_dict_load.update(json.loads(f.read())) + return model_hparams_dict + +def resume_checkpoint(resume,checkpoint,output_dir): + """ + Resume the existing model checkpoints and set checkpoint directory + """ + if resume: + if checkpoint: + raise ValueError('resume and checkpoint cannot both be specified') + checkpoint = output_dir + return checkpoint + +def set_seed(seed): + if seed is not None: + tf.set_random_seed(seed) + np.random.seed(seed) + random.seed(seed) + +def load_params_from_checkpoints_dir(model_hparams_dict,checkpoint,dataset,model): + + model_hparams_dict_load = {} + if model_hparams_dict: + with open(model_hparams_dict) as f: + model_hparams_dict_load.update(json.loads(f.read())) + + if checkpoint: + checkpoint_dir = os.path.normpath(checkpoint) + if not os.path.isdir(checkpoint): + checkpoint_dir, _ = os.path.split(checkpoint_dir) + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) + with open(os.path.join(checkpoint_dir, "options.json")) as f: + print("loading options from checkpoint %s" % args.checkpoint) + options = json.loads(f.read()) + dataset = dataset or options['dataset'] + model = model or options['model'] + try: + with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: + model_hparams_dict_load.update(json.loads(f.read())) + except FileNotFoundError: + print("model_hparams.json was not loaded because it does not exist") + return dataset, model, model_hparams_dict_load + +def setup_dataset(dataset,input_dir,val_input_dir): + VideoDataset = datasets.get_dataset_class(dataset) + train_dataset = VideoDataset( + input_dir, + mode='train') + val_dataset = VideoDataset( + val_input_dir or input_dir, + mode='val') + variable_scope = tf.get_variable_scope() + variable_scope.set_use_resource(True) + return train_dataset,val_dataset,variable_scope + +def setup_model(model,model_hparams_dict,train_dataset,model_hparams): + """ + Set up model instance + """ + VideoPredictionModel = models.get_model_class(model) + hparams_dict = dict(model_hparams_dict) + hparams_dict.update({ + 'context_frames': train_dataset.hparams.context_frames, + 'sequence_length': train_dataset.hparams.sequence_length, + 'repeat': train_dataset.hparams.time_shift, + }) + model = VideoPredictionModel( + hparams_dict=hparams_dict, + hparams=model_hparams) + return model + +def save_dataset_model_params_to_checkpoint_dir(args,output_dir,train_dataset,model): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with 
open(os.path.join(output_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(args), sort_keys=True, indent=4)) + with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: + f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) + with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: + f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) + return None + +def make_dataset_iterator(train_dataset, val_dataset, batch_size ): + train_tf_dataset = train_dataset.make_dataset_v2(batch_size) + train_iterator = train_tf_dataset.make_one_shot_iterator() + # The `Iterator.string_handle()` method returns a tensor that can be evaluated + # and used to feed the `handle` placeholder. + train_handle = train_iterator.string_handle() + val_tf_dataset = val_dataset.make_dataset_v2(batch_size) + val_iterator = val_tf_dataset.make_one_shot_iterator() + val_handle = val_iterator.string_handle() + #iterator = tf.data.Iterator.from_string_handle( + # train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) + inputs = train_iterator.get_next() + val = val_iterator.get_next() + return inputs,train_handle, val_handle + + +def plot_train(train_losses,val_losses,output_dir): + iterations = list(range(len(train_losses))) + plt.plot(iterations, train_losses, 'g', label='Training loss') + plt.plot(iterations, val_losses, 'b', label='validation loss') + plt.title('Training and Validation loss') + plt.xlabel('Iterations') + plt.ylabel('Loss') + plt.legend() + plt.savefig(os.path.join(output_dir,'plot_train.png')) + +def save_results_to_dict(results_dict,output_dir): + with open(os.path.join(output_dir,"results.json"),"w") as fp: + json.dump(results_dict,fp) + +def save_results_to_pkl(train_losses,val_losses, output_dir): + with open(os.path.join(output_dir,"train_losses.pkl"),"wb") as f: + pkl.dump(train_losses,f) + with open(os.path.join(output_dir,"val_losses.pkl"),"wb") as f: + pkl.dump(val_losses,f) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " + "train, val, test, etc, or a directory containing " + "the tfrecords") + parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir") + parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified") + parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " + "default is logs_dir/model_fname, where model_fname consists of " + "information from model and model_hparams") + parser.add_argument("--output_dir_postfix", default="") + parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. 
checkpoint_dir/model-200000)") + parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.') + + parser.add_argument("--dataset", type=str, help="dataset class name") + parser.add_argument("--model", type=str, help="model class name") + parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") + parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") + + parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="fraction of gpu memory to use") + parser.add_argument("--seed",default=1234, type=int) + + args = parser.parse_args() + + #Set seed + set_seed(args.seed) + + #setup output directory + args.output_dir = generate_output_dir(args.output_dir, args.model, args.model_hparams, args.logs_dir, args.output_dir_postfix) + + #resume the existing checkpoint and set up the checkpoint directory to output directory + args.checkpoint = resume_checkpoint(args.resume,args.checkpoint,args.output_dir) + + #get model hparams dict from json file + #load the existing checkpoint related datasets, model configure (This information was stored in the checkpoint dir when last time training model) + args.dataset,args.model,model_hparams_dict = load_params_from_checkpoints_dir(args.model_hparams_dict,args.checkpoint,args.dataset,args.model) + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + #setup training val datset instance + train_dataset,val_dataset,variable_scope = setup_dataset(args.dataset,args.input_dir,args.val_input_dir) + + #setup model instance + model=setup_model(args.model,model_hparams_dict,train_dataset,args.model_hparams) + + batch_size = model.hparams.batch_size + #Create input and val iterator + inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size) + + #build model graph + #del inputs["T_start"] + model.build_graph(inputs) + + #save all the model, data params to output dirctory + save_dataset_model_params_to_checkpoint_dir(args,args.output_dir,train_dataset,model) + + with tf.name_scope("parameter_count"): + # exclude trainable variables that are replicas (used in multi-gpu setting) + trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables) + parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables]) + + saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2) + + # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero + summary_writer = tf.summary.FileWriter(args.output_dir) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True) + config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) + + max_epochs = model.hparams.max_epochs #the number of epochs + num_examples_per_epoch = train_dataset.num_examples_per_epoch() + print ("number of exmaples per epoch:",num_examples_per_epoch) + steps_per_epoch = int(num_examples_per_epoch/batch_size) + total_steps = steps_per_epoch * max_epochs + global_step = tf.train.get_or_create_global_step() + #mock total_steps only for fast debugging + #total_steps = 10 + print ("Total steps for training:",total_steps) + results_dict = {} + with tf.Session(config=config) as sess: + 
print("parameter_count =", sess.run(parameter_count)) + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + model.restore(sess, args.checkpoint) + sess.graph.finalize() + #start_step = sess.run(model.global_step) + start_step = sess.run(global_step) + print("start_step", start_step) + # start at one step earlier to log everything without doing any training + # step is relative to the start_step + train_losses=[] + val_losses=[] + run_start_time = time.time() + for step in range(start_step,total_steps): + #global_step = sess.run(global_step):q + + print ("step:", step) + val_handle_eval = sess.run(val_handle) + + #Fetch variables in the graph + + fetches = {"train_op": model.train_op} + #fetches["latent_loss"] = model.latent_loss + fetches["summary"] = model.summary_op + + if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + fetches["global_step"] = model.global_step + fetches["total_loss"] = model.total_loss + #fetch the specific loss function only for mcnet + if model.__class__.__name__ == "McNetVideoPredictionModel": + fetches["L_p"] = model.L_p + fetches["L_gdl"] = model.L_gdl + fetches["L_GAN"] =model.L_GAN + if model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + fetches["latent_loss"] = model.latent_loss + fetches["recon_loss"] = model.recon_loss + results = sess.run(fetches) + train_losses.append(results["total_loss"]) + #Fetch losses for validation data + val_fetches = {} + #val_fetches["latent_loss"] = model.latent_loss + val_fetches["total_loss"] = model.total_loss + + + if model.__class__.__name__ == "SAVPVideoPredictionModel": + fetches['d_loss'] = model.d_loss + fetches['g_loss'] = model.g_loss + fetches['d_losses'] = model.d_losses + fetches['g_losses'] = model.g_losses + results = sess.run(fetches) + train_losses.append(results["g_losses"]) + val_fetches = {} + #val_fetches["latent_loss"] = model.latent_loss + #For SAVP the total loss is the generator loses + val_fetches["total_loss"] = model.g_losses + + val_fetches["summary"] = model.summary_op + val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval}) + val_losses.append(val_results["total_loss"]) + + summary_writer.add_summary(results["summary"]) + summary_writer.add_summary(val_results["summary"]) + summary_writer.flush() + + # global_step will have the correct step count if we resume from a checkpoint + # global step is read before it's incemented + train_epoch = step/steps_per_epoch + print("progress global step %d epoch %0.1f" % (step + 1, train_epoch)) + if model.__class__.__name__ == "McNetVideoPredictionModel": + print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"])) + elif model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": + print ("Total_loss:{}".format(results["total_loss"])) + elif model.__class__.__name__ == "SAVPVideoPredictionModel": + print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"])) + elif model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"])) + else: + print ("The model name does not exist") + + #print("saving model to", args.output_dir) + saver.save(sess, 
os.path.join(args.output_dir, "model"), global_step=step)# + train_time = time.time() - run_start_time + results_dict = {"train_time":train_time, + "total_steps":total_steps} + save_results_to_dict(results_dict,args.output_dir) + save_results_to_pkl(train_losses, val_losses, args.output_dir) + print("train_losses:",train_losses) + print("val_losses:",val_losses) + plot_train(train_losses,val_losses,args.output_dir) + print("Done") + +if __name__ == '__main__': + main() diff --git a/video_prediction_savp/video_prediction/datasets/__init__.py b/video_prediction_savp/video_prediction/datasets/__init__.py index e449a65bd48b14ef5a11e6846ce4d8f39f7ed193..f58607c2f4c14047aefb36956e98bd228a30aeb1 100644 --- a/video_prediction_savp/video_prediction/datasets/__init__.py +++ b/video_prediction_savp/video_prediction/datasets/__init__.py @@ -7,6 +7,7 @@ from .kth_dataset import KTHVideoDataset from .ucf101_dataset import UCF101VideoDataset from .cartgripper_dataset import CartgripperVideoDataset from .era5_dataset_v2 import ERA5Dataset_v2 +from .moving_mnist import MovingMnist #from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly def get_dataset_class(dataset): @@ -19,6 +20,7 @@ def get_dataset_class(dataset): 'ucf101': 'UCF101VideoDataset', 'cartgripper': 'CartgripperVideoDataset', "era5":"ERA5Dataset_v2", + "moving_mnist":"MovingMnist" # "era5_anomaly":"ERA5Dataset_v2_anomaly", } dataset_class = dataset_mappings.get(dataset, dataset) diff --git a/video_prediction_savp/video_prediction/datasets/kth_dataset.py b/video_prediction_savp/video_prediction/datasets/kth_dataset.py index 40fb6bf57b8219fc7e75c3759df9e6b38fffeb30..e1e11d51968e706868fd89f26faa25d1999d3a9b 100644 --- a/video_prediction_savp/video_prediction/datasets/kth_dataset.py +++ b/video_prediction_savp/video_prediction/datasets/kth_dataset.py @@ -136,11 +136,11 @@ def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequence def main(): parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, help="directory containing the processed directories " + parser.add_argument("input_dir", type=str, help="directory containing the processed directories " "boxing, handclapping, handwaving, " "jogging, running, walking") - parser.add_argument("--output_dir", type=str) - parser.add_argument("--image_size", type=int) + parser.add_argument("output_dir", type=str) + parser.add_argument("image_size", type=int) args = parser.parse_args() partition_names = ['train', 'val', 'test'] diff --git a/video_prediction_savp/video_prediction/datasets/moving_mnist.py b/video_prediction_savp/video_prediction/datasets/moving_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..12a517ac9540dd443866d47ea689803abe9e9a4d --- /dev/null +++ b/video_prediction_savp/video_prediction/datasets/moving_mnist.py @@ -0,0 +1,241 @@ +import argparse +import glob +import itertools +import os +import pickle +import random +import re +import numpy as np +import json +import tensorflow as tf +from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +# ML 2020/04/14: hack for getting functions of process_netCDF_v2: +from os import path +import sys +sys.path.append(path.abspath('../../workflow_parallel_frame_prediction/')) +import DataPreprocess.process_netCDF_v2 +from DataPreprocess.process_netCDF_v2 import get_unique_vars +from DataPreprocess.process_netCDF_v2 import Calc_data_stat +from metadata import MetaData +#from base_dataset import VarLenFeatureVideoDataset +from collections import 
OrderedDict +from tensorflow.contrib.training import HParams +from mpi4py import MPI +import glob +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + +class MovingMnist(VarLenFeatureVideoDataset): + def __init__(self, *args, **kwargs): + super(MovingMnist, self).__init__(*args, **kwargs) + from google.protobuf.json_format import MessageToDict + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + print("features in dataset:",feature.keys()) + self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) + self.image_shape = self.video_shape[1:] + self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape + + def get_default_hparams_dict(self): + default_hparams = super(MovingMnist, self).get_default_hparams_dict() + hparams = dict( + context_frames=10,#Bing: Todo oriignal is 10 + sequence_length=20,#bing: TODO original is 20, + shuffle_on_val=True, + ) + return dict(itertools.chain(default_hparams.items(), hparams.items())) + + + @property + def jpeg_encoding(self): + return False + + + + def num_examples_per_epoch(self): + with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: + sequence_lengths = sequence_lengths_file.readlines() + sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] + return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) + + def filter(self, serialized_example): + return tf.convert_to_tensor(True) + + + def make_dataset_v2(self, batch_size): + def parser(serialized_example): + seqs = OrderedDict() + keys_to_features = { + 'width': tf.FixedLenFeature([], tf.int64), + 'height': tf.FixedLenFeature([], tf.int64), + 'sequence_length': tf.FixedLenFeature([], tf.int64), + 'channels': tf.FixedLenFeature([], tf.int64), + 'images/encoded': tf.VarLenFeature(tf.float32) + } + + # for i in range(20): + # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) + parsed_features = tf.parse_single_example(serialized_example, keys_to_features) + print ("Parse features", parsed_features) + seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) + #width = tf.sparse_tensor_to_dense(parsed_features["width"]) + # height = tf.sparse_tensor_to_dense(parsed_features["height"]) + # channels = tf.sparse_tensor_to_dense(parsed_features["channels"]) + # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"]) + images = [] + print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) + images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + seqs["images"] = images + return seqs + filenames = self.filenames + print ("FILENAMES",filenames) + #TODO: + #temporal_filenames = self.temporal_filenames + shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) + if shuffle: + random.shuffle(filenames) + dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) # todo: what is buffer_size + print("files", self.filenames) + print("mode", self.mode) + dataset = dataset.filter(self.filter) + if shuffle: + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) + else: + dataset = 
dataset.repeat(self.num_epochs) + + num_parallel_calls = None if shuffle else 1 + dataset = dataset.apply(tf.contrib.data.map_and_batch( + parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) + #dataset = dataset.map(parser) + # num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. sampled subclips from the test set) + # dataset = dataset.apply(tf.contrib.data.map_and_batch( + # _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs + dataset = dataset.prefetch(batch_size) # Bing: Take the data to buffer inorder to save the waiting time for GPU + return dataset + + + + def make_batch(self, batch_size): + dataset = self.make_dataset_v2(batch_size) + iterator = dataset.make_one_shot_iterator() + return iterator.get_next() + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bytes_list_feature(values): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + +def _floats_feature(value): + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + +def save_tf_record(output_fname, sequences): + with tf.python_io.TFRecordWriter(output_fname) as writer: + for i in range(len(sequences)): + sequence = sequences[:,i,:,:,:] + num_frames = len(sequence) + height, width = sequence[0,:,:,0].shape + encoded_sequence = np.array([list(image) for image in sequence]) + features = tf.train.Features(feature={ + 'sequence_length': _int64_feature(num_frames), + 'height': _int64_feature(height), + 'width': _int64_feature(width), + 'channels': _int64_feature(1), + 'images/encoded': _floats_feature(encoded_sequence.flatten()), + }) + example = tf.train.Example(features=features) + writer.write(example.SerializeToString()) + +def read_frames_and_save_tf_records(output_dir,dat_npz, seq_length=20, sequences_per_file=128, height=64, width=64):#Bing: original 128 + """ + Read the moving_mnst data which is npz format, and save it to tfrecords files + The shape of dat_npz is [seq_length,number_samples,height,width] + moving_mnst only has one channel + + """ + os.makedirs(output_dir,exist_ok=True) + idx = 0 + num_samples = dat_npz.shape[1] + dat_npz = np.expand_dims(dat_npz, axis=4) #add one dim to represent channel, then got [seq_length,num_samples,height,width,channel] + print("data_npz_shape",dat_npz.shape) + dat_npz = dat_npz.astype(np.float32) + dat_npz /= 255.0 #normalize RGB codes by dividing it to the max RGB value + while idx < num_samples - sequences_per_file: + sequences = dat_npz[:,idx:idx+sequences_per_file,:,:,:] + output_fname = 'sequence_{}_{}.tfrecords'.format(idx,idx+sequences_per_file) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, sequences) + idx = idx + sequences_per_file + return None + + +def write_sequence_file(output_dir,seq_length,sequences_per_file): + partition_names = ["train","val","test"] + for partition_name in partition_names: + save_output_dir = os.path.join(output_dir,partition_name) + tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords")) + print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter)) + sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w') + for i in 
range(tfCounter*sequences_per_file):
+            sequence_lengths_file.write("%d\n" % seq_length)
+        sequence_lengths_file.close()
+
+
+def plot_seq_imgs(imgs, output_png_dir, idx, label="Ground Truth"):
+    """
+    Plot a sequence of images; imgs must have the shape (seq_len, lat, lon).
+    """
+    if len(np.array(imgs).shape) != 3:
+        raise ValueError("img dims should be three: (seq_len, lat, lon)")
+    img_len = imgs.shape[0]
+    fig = plt.figure(figsize=(18, 6))
+    gs = gridspec.GridSpec(1, img_len)
+    gs.update(wspace=0., hspace=0.)
+    for i in range(img_len):
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i], cmap='jet')
+        plt.setp([ax1], xticks=[], xticklabels=[], yticks=[], yticklabels=[])
+    plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg"))
+    print("images saved")
+    plt.clf()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_dir", type=str, help="directory containing the mnist_test_seq.npy file")
+    parser.add_argument("output_dir", type=str)
+    parser.add_argument("-sequences_per_file", type=int, default=2)
+    args = parser.parse_args()
+    current_path = os.getcwd()
+    data = np.load(os.path.join(args.input_dir, "mnist_test_seq.npy"))
+    print("shape of data in mnist_test_seq.npy:", data.shape)
+    seq_length = data.shape[0]
+    height = data.shape[2]
+    width = data.shape[3]
+    num_samples = data.shape[1]
+    max_npz = np.max(data)
+    min_npz = np.min(data)
+    print("max_npz:", max_npz)
+    print("min_npz:", min_npz)
+    # TODO: discuss how to split the data. The file holds 10000 samples in total; the original
+    # ConvLSTM paper used 10000 for training, 2000 for validation and 3000 for testing, whereas
+    # here the 10000 samples are split into 6000/1000/3000 for train/val/test.
+    dat_train = data[:, :6000, :, :]
+    dat_val = data[:, 6000:7000, :, :]
+    dat_test = data[:, 7000:, :, :]
+    #plot_seq_imgs(dat_test[10:,0,:,:],output_png_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist/convLSTM",idx=1,label="Ground Truth from npz")
+    # NOTE: enable the calls below to (re-)generate the TFRecord and sequence-length files for each partition
+    #save train
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"train"),dat_train, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #save val
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"val"),dat_val, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #save test
+    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"test"),dat_test, seq_length=20, sequences_per_file=40, height=height, width=width)
+    #write_sequence_file(output_dir=args.output_dir,seq_length=20,sequences_per_file=40)
+if __name__ == '__main__':
+    main()
+
diff --git a/video_prediction_savp/video_prediction/layers/layer_def.py b/video_prediction_savp/video_prediction/layers/layer_def.py
index a59643c7a6d69141134ec01c9b147c4798bfed8e..738e139df0e155ca294fe43edd03a2d79fc1f532 100644
--- a/video_prediction_savp/video_prediction/layers/layer_def.py
+++ b/video_prediction_savp/video_prediction/layers/layer_def.py
@@ -3,8 +3,8 @@ import tensorflow as tf
 import numpy as np
-
 weight_decay = 0.0005
+
 def _activation_summary(x):
     """Helper to create summaries for activations.
     Creates a summary that provides a histogram of activations.
@@ -157,4 +157,4 @@ def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): def bn_layers_wrapper(inputs, is_training): pass - \ No newline at end of file + diff --git a/video_prediction_savp/video_prediction/models/__init__.py b/video_prediction_savp/video_prediction/models/__init__.py index 6d7323f3750949b0ddb411d4a98934928537bc53..b71769a9d8cc523e4f108fc222fc2ed0284019f7 100644 --- a/video_prediction_savp/video_prediction/models/__init__.py +++ b/video_prediction_savp/video_prediction/models/__init__.py @@ -21,7 +21,7 @@ def get_model_class(model): 'vae': 'VanillaVAEVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', 'mcnet': 'McNetVideoPredictionModel', - + 'convLSTM_Loliver': "ConvLstmLoliverVideoPredictionModel" } model_class = model_mappings.get(model, model) model_class = globals().get(model_class) diff --git a/video_prediction_savp/video_prediction/models/base_model.py b/video_prediction_savp/video_prediction/models/base_model.py index 846621d8ca1e235c39618951be86fe184a2d974d..df479968325946a9d61896d498428d65692c1848 100644 --- a/video_prediction_savp/video_prediction/models/base_model.py +++ b/video_prediction_savp/video_prediction/models/base_model.py @@ -487,6 +487,7 @@ class VideoPredictionModel(BaseVideoPredictionModel): # skip_vars = {" discriminator_encoder/video_sn_fc4/dense/bias"} if self.num_gpus <= 1: # cpu or 1 gpu + print("self.inputs:>20200822",self.inputs) outputs_tuple, losses_tuple, loss_tuple, metrics_tuple = self.tower_fn(self.inputs) self.outputs, self.eval_outputs = outputs_tuple self.d_losses, self.g_losses, g_losses_post = losses_tuple diff --git a/video_prediction_savp/video_prediction/models/savp_model.py b/video_prediction_savp/video_prediction/models/savp_model.py index c510d050c89908d0e06fe0f1a66e355e61c90530..039533864f34d1608c5a10d4a664d40ce73594a7 100644 --- a/video_prediction_savp/video_prediction/models/savp_model.py +++ b/video_prediction_savp/video_prediction/models/savp_model.py @@ -689,6 +689,10 @@ class SAVPCell(tf.nn.rnn_cell.RNNCell): def generator_given_z_fn(inputs, mode, hparams): # all the inputs needs to have the same length for unrolling the rnn print("inputs.items",inputs.items()) + #20200822 bing + inputs ={"images":inputs["images"]} + print("inputs 20200822:",inputs) + #20200822 inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1) for name, input in inputs.items()} cell = SAVPCell(inputs, mode, hparams) diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py index 744284fc6c5b52bcde249f1a58e04a41e80339fa..01a4f7ce5d6430f19a1e4b99c4cba956b3f7682b 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py @@ -94,6 +94,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): def convLSTM_cell(inputs, hidden): y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset + channels = inputs.get_shape()[-1] # conv lstm cell cell_shape = y_0.get_shape().as_list() channels = cell_shape[-1]
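# A minimal, self-contained sketch (not part of the diff above) for sanity-checking the
# TFRecords written by save_tf_record()/read_frames_and_save_tf_records() in moving_mnist.py.
# It decodes one serialized example using the same feature keys the writer uses
# ('sequence_length', 'height', 'width', 'channels', 'images/encoded') and reshapes the
# flattened float sequence the same way make_dataset_v2() does. The file name passed at the
# bottom is a hypothetical example; point it at an existing *.tfrecords file.
import numpy as np
import tensorflow as tf


def inspect_first_sequence(tfrecord_path):
    # read the first serialized example from the TFRecord file (TF 1.x API, as used in the repo)
    serialized = next(tf.python_io.tf_record_iterator(tfrecord_path))
    example = tf.train.Example.FromString(serialized)
    feature = example.features.feature
    seq_len = feature['sequence_length'].int64_list.value[0]
    height = feature['height'].int64_list.value[0]
    width = feature['width'].int64_list.value[0]
    channels = feature['channels'].int64_list.value[0]
    # 'images/encoded' holds the flattened pixel values, already normalized to [0, 1]
    pixels = np.array(feature['images/encoded'].float_list.value, dtype=np.float32)
    sequence = pixels.reshape(seq_len, height, width, channels)
    print("sequence shape:", sequence.shape, "value range:", sequence.min(), "-", sequence.max())
    return sequence


if __name__ == '__main__':
    # hypothetical path; with sequences_per_file=40 the first file written for a partition is named like this
    inspect_first_sequence("train/sequence_0_40.tfrecords")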