diff --git a/scripts/generate_transfer_learning_finetune.py b/scripts/generate_transfer_learning_finetune.py index 662a415d69ca5f8cf8cdd5e9c800a1403976a551..c4fa831594910b3389a873cc9f8d4dd87944d66e 100644 --- a/scripts/generate_transfer_learning_finetune.py +++ b/scripts/generate_transfer_learning_finetune.py @@ -214,6 +214,8 @@ def main(): print("Sample id", sample_ind) if sample_ind <= 24: pass + elif sample_ind >= len(X_test): + break else: gen_images_stochastic = [] if args.num_samples and sample_ind >= args.num_samples: @@ -239,7 +241,6 @@ def main(): #bing:20200417 t_stampe = test_temporal_pkl[sample_ind+i] print("timestamp:",type(t_stampe)) - persistent_ts = np.array(t_stampe) - datetime.timedelta(days=1) print ("persistent ts",persistent_ts) persistent_idx = list(test_temporal_pkl).index(np.array(persistent_ts)) @@ -323,25 +324,25 @@ def main(): plt.savefig(os.path.join(args.output_png_dir, "Persistent_Sample_" + str(name) + ".jpg")) plt.clf() -## -## with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: -## pickle.dump(list(persistent_images_all), input_files) -## print ("Save persistent all") -## if is_first: -## gen_images_all = gen_images_stochastic -## is_first = False -## else: -## gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) -## -## if args.num_stochastic_samples == 1: -## with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: -## pickle.dump(list(gen_images_all[0]), gen_files) -## print ("Save generate all") -## else: -## with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: -## pickle.dump(list(gen_images_stochastic), gen_files) -## with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: -## pickle.dump(list(gen_images_all), gen_files) + + with open(os.path.join(args.output_png_dir, "persistent_images_all.pkl"), "wb") as input_files: + pickle.dump(list(persistent_images_all), input_files) + print ("Save persistent all") + if is_first: + gen_images_all = gen_images_stochastic + is_first = False + else: + gen_images_all = np.concatenate((np.array(gen_images_all), np.array(gen_images_stochastic)), axis=1) + + if args.num_stochastic_samples == 1: + with open(os.path.join(args.output_png_dir, "gen_images_all.pkl"), "wb") as gen_files: + pickle.dump(list(gen_images_all[0]), gen_files) + print ("Save generate all") + else: + with open(os.path.join(args.output_png_dir, "gen_images_sample_id_" + str(sample_ind)),"wb") as gen_files: + pickle.dump(list(gen_images_stochastic), gen_files) + with open(os.path.join(args.output_png_dir, "gen_images_all_stochastic"), "wb") as gen_files: + pickle.dump(list(gen_images_all), gen_files) ## ## ## # fig = plt.figure()