diff --git a/scripts/Analysis_all.py b/scripts/Analysis_all.py index 65a6000a7bff2047250b73b4035a78a1c4e9d136..6d61689bf86e4e6e6ba5ac5fd23f944a33c55eee 100644 --- a/scripts/Analysis_all.py +++ b/scripts/Analysis_all.py @@ -8,7 +8,6 @@ from matplotlib.pylab import plt # model_names = ["SAVP","SAVP_Finetune","GAN","VAE"] - # results_path = ["results_test_samples/era5_size_64_64_3_norm_dup_pretrained/ours_savp","results_test_samples/era5_size_64_64_3_norm_msl_gph_pretrained_savp/ours_savp", # "results_test_samples/era5_size_64_64_3_norm_dup_pretrained_gan/kth_ours_gan","results_test_samples/era5_size_64_64_3_norm_msl_gph_pretrained_gan/kth_ours_gan"] # @@ -38,6 +37,7 @@ for path in results_path: psnr_all.append(psnr) ssim_all.append(ssim) + def get_metric(metrtic): if metric == "mse": return mse_all diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py index f3795b4230b2736cd122306da4b96ca9e41b067f..80470df25b5db6854c15d343bf2e62c596b40cb5 100644 --- a/scripts/generate_transfer_learning_finetune.py +++ b/scripts/generate_transfer_learning_finetune.py @@ -79,7 +79,7 @@ def main(): parser.add_argument("--num_samples", type = int, help = "number of samples in total (all of them by default)") parser.add_argument("--num_epochs", type = int, default = 1) - parser.add_argument("--num_stochastic_samples", type = int, default = 5) + parser.add_argument("--num_stochastic_samples", type = int, default = 1) parser.add_argument("--gif_length", type = int, help = "default is sequence_length") parser.add_argument("--fps", type = int, default = 4) @@ -360,15 +360,24 @@ def main(): # else: # pass - # with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)), "wb") as gen_files: - # pickle.dump(list(gen_images_stochastic), gen_files) + # # if sample_ind == 0: # gen_images_all = gen_images_stochastic # else: - # gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis = 1) - # with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files: - # pickle.dump(list(gen_images_all), gen_files) + # gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + # + # if args.num_stochastic_samples == 1: + # with open(os.path.join(args.output_png_dir, "gen_images_all"), "wb") as gen_files: + # pickle.dump(list(gen_images_all[0]), gen_files) + # else: + # with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: + # pickle.dump(list(gen_images_stochastic), gen_files) + # with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: + # pickle.dump(list(gen_images_all), gen_files) + # + # + # # # sample_ind += args.batch_size @@ -579,7 +588,7 @@ def main(): # plt.savefig(os.path.join(args.output_png_dir, "kdp_gen_images.png"), dpi = 400) # plt.clf() - # #line plot for evaluating the prediction and groud-truth + #line plot for evaluating the prediction and groud-truth # for i in [0,3,6,9,12,15,18]: # fig = plt.figure() # plt.scatter(gen_images_all[:,i,:,:][s].flatten(),input_images_all[:,i,:,:][s].flatten(),s=0.3) diff --git a/scripts/train_v2.py b/scripts/train_v2.py index 1f99a3193ccf12ffb15f75946991548c43b6c278..f83192d6af7d3953c666f40cc9e6d3766a78e92e 100644 --- a/scripts/train_v2.py +++ b/scripts/train_v2.py @@ -150,8 +150,8 @@ def main(): VideoPredictionModel = models.get_model_class(args.model) hparams_dict = dict(model_hparams_dict) hparams_dict.update({ - 'context_frames': train_dataset.hparams.context_frames,#Bing: TODO what is context_frames? - 'sequence_length': train_dataset.hparams.sequence_length,#Bing: TODO what is sequence_frames + 'context_frames': train_dataset.hparams.context_frames, + 'sequence_length': train_dataset.hparams.sequence_length, 'repeat': train_dataset.hparams.time_shift, }) model = VideoPredictionModel(