From 4386224ebe4522e76c9e245a3c1708974a45a99a Mon Sep 17 00:00:00 2001 From: Alex Lee <alexleegk@gmail.com> Date: Wed, 23 Jan 2019 11:46:52 -0800 Subject: [PATCH] Fix evaluate and generate scripts. --- scripts/evaluate.py | 1 + scripts/generate.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 88d44d15..2f7f4c3e 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -231,6 +231,7 @@ def main(): 'repeat': dataset.hparams.time_shift, }) model = VideoPredictionModel( + mode=args.mode, hparams_dict=hparams_dict, hparams=args.model_hparams, eval_num_samples=args.num_stochastic_samples, diff --git a/scripts/generate.py b/scripts/generate.py index 534eb5fa..44b4b42a 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -112,6 +112,7 @@ def main(): 'repeat': dataset.hparams.time_shift, }) model = VideoPredictionModel( + mode=args.mode, hparams_dict=hparams_dict, hparams=args.model_hparams) @@ -166,11 +167,15 @@ def main(): # only keep the future frames gen_images = gen_images[:, -future_length:] for i, gen_images_ in enumerate(gen_images): + context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) gen_images_ = (gen_images_ * 255.0).astype(np.uint8) gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + if args.gif_length: + context_and_gen_images = context_and_gen_images[:args.gif_length] save_gif(os.path.join(args.output_gif_dir, gen_images_fname), - gen_images_[:args.gif_length] if args.gif_length else gen_images_, fps=args.fps) + context_and_gen_images, fps=args.fps) gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) for t, gen_image in enumerate(gen_images_): -- GitLab