diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 88d44d1529e4d6e5e57d711f377733ea79b355bf..2f7f4c3e57b88f22d504bf4fc664dfae7a483243 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -231,6 +231,7 @@ def main(): 'repeat': dataset.hparams.time_shift, }) model = VideoPredictionModel( + mode=args.mode, hparams_dict=hparams_dict, hparams=args.model_hparams, eval_num_samples=args.num_stochastic_samples, diff --git a/scripts/generate.py b/scripts/generate.py index 534eb5fa902341e3ca65a1f733f2c5f841c84210..44b4b42a2100746ab508ba275dfa3c4fc6f9ebf5 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -112,6 +112,7 @@ def main(): 'repeat': dataset.hparams.time_shift, }) model = VideoPredictionModel( + mode=args.mode, hparams_dict=hparams_dict, hparams=args.model_hparams) @@ -166,11 +167,15 @@ def main(): # only keep the future frames gen_images = gen_images[:, -future_length:] for i, gen_images_ in enumerate(gen_images): + context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) gen_images_ = (gen_images_ * 255.0).astype(np.uint8) gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) + context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) + if args.gif_length: + context_and_gen_images = context_and_gen_images[:args.gif_length] save_gif(os.path.join(args.output_gif_dir, gen_images_fname), - gen_images_[:args.gif_length] if args.gif_length else gen_images_, fps=args.fps) + context_and_gen_images, fps=args.fps) gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) for t, gen_image in enumerate(gen_images_):