From 08bb6f671877cf286bb75f873aa860ffe801b3ed Mon Sep 17 00:00:00 2001 From: "b.gong" <b.gong@fz-juelich.de> Date: Mon, 18 May 2020 12:22:35 +0200 Subject: [PATCH] update generate_transfer_learning_finetune.py --- .../generate_transfer_learning_finetune.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py index 662a415d..c4fa8315 100644 --- a/scripts/generate_transfer_learning_finetune.py +++ b/scripts/generate_transfer_learning_finetune.py @@ -214,6 +214,8 @@ def main(): print("Sample id", sample_ind) if sample_ind <= 24: pass + elif sample_ind >= len(X_test): + break else: gen_images_stochastic = [] if args.num_samples and sample_ind >= args.num_samples: @@ -239,7 +241,6 @@ def main(): #bing:20200417 t_stampe = test_temporal_pkl[sample_ind+i] print("timestamp:",type(t_stampe)) - persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1) print ("persistent ts",persistent_ts) persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts)) @@ -323,25 +324,25 @@ def main(): plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg")) plt.clf() -## -## with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: -## pickle.dump(list(persistent_images_all), input_files) -## print ("Save persistent all") -## if is_first: -## gen_images_all = gen_images_stochastic -## is_first = False -## else: -## gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) -## -## if args.num_stochastic_samples == 1: -## with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: -## pickle.dump(list(gen_images_all[0]), gen_files) -## print ("Save generate all") -## else: -## with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: -## pickle.dump(list(gen_images_stochastic), gen_files) -## with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: -## pickle.dump(list(gen_images_all), gen_files) + + with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: + pickle.dump(list(persistent_images_all), input_files) + print ("Save persistent all") + if is_first: + gen_images_all = gen_images_stochastic + is_first = False + else: + gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + + if args.num_stochastic_samples == 1: + with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: + pickle.dump(list(gen_images_all[0]), gen_files) + print ("Save generate all") + else: + with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: + pickle.dump(list(gen_images_stochastic), gen_files) + with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: + pickle.dump(list(gen_images_all), gen_files) ## ## ## # fig = plt.figure() -- GitLab