diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 24d84fb6105a6b4a3485680991d9fa46b3b1cb9d..b9fd44a3476b72035e1608c2f849342d7bec9964 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -16,73 +16,18 @@ import tensorflow as tf
 from video_prediction import datasets, models
 
 
-def compute_expectation_np(pix_distrib):
-    assert pix_distrib.shape[-1] == 1
-    pix_distrib = pix_distrib / np.sum(pix_distrib, axis=(-3, -2), keepdims=True)
-    height, width = pix_distrib.shape[-3:-1]
-    xv, yv = np.meshgrid(np.arange(width), np.arange(height))
-    return np.stack([np.sum(yv[:, :, None] * pix_distrib, axis=(-3, -2, -1)),
-                     np.sum(xv[:, :, None] * pix_distrib, axis=(-3, -2, -1))], axis=-1)
-
-
-def as_heatmap(image, normalize=True):
-    import matplotlib.pyplot as plt
-    image = np.squeeze(image, axis=-1)
-    if normalize:
-        image = image / np.max(image, axis=(-2, -1), keepdims=True)
-    cmap = plt.get_cmap('viridis')
-    heatmap = cmap(image)[..., :3]
-    return heatmap
-
-
-def rgb2gray(rgb):
-    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
-
-
-def resize_and_draw_circle(image, size, center, radius, dpi=128.0, **kwargs):
-    import matplotlib.pyplot as plt
-    from matplotlib.patches import Circle
-    import io
-    height, width = size
-    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
-    ax = fig.add_axes([0, 0, 1, 1])
-    ax.imshow(image, interpolation='none')
-    circle = Circle(center[::-1], radius=radius, **kwargs)
-    ax.add_patch(circle)
-    ax.axis("off")
-    fig.canvas.draw()
-    trans = ax.figure.dpi_scale_trans.inverted()
-    bbox = ax.bbox.transformed(trans)
-    buff = io.BytesIO()
-    plt.savefig(buff, format="png", dpi=ax.figure.dpi, bbox_inches=bbox)
-    buff.seek(0)
-    image = plt.imread(buff)[..., :3]
-    plt.close(fig)
-    return image
-
-
-def save_image_sequence(prefix_fname, images, overlaid_images=None, centers=None,
-                        radius=5, alpha=0.8, time_start_ind=0):
+def save_image_sequence(prefix_fname, images, time_start_ind=0):
     import cv2
     head, tail = os.path.split(prefix_fname)
     if head and not os.path.exists(head):
         os.makedirs(head)
-    if images.shape[-1] == 1:
-        images = as_heatmap(images)
-    if overlaid_images is not None:
-        assert images.shape[-1] == 3
-        assert overlaid_images.shape[-1] == 1
-        gray_images = rgb2gray(images)
-        overlaid_images = as_heatmap(overlaid_images)
-        images = (1 - alpha) * gray_images[..., None] + alpha * overlaid_images
     for t, image in enumerate(images):
         image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t)
-        if centers is not None:
-            scale = np.max(np.array([256, 256]) / np.array(image.shape[:2]))
-            image = resize_and_draw_circle(image, np.array(image.shape[:2]) * scale, centers[t], radius,
-                                           edgecolor='r', fill=False, linestyle='--', linewidth=2)
         image = (image * 255.0).astype(np.uint8)
-        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        if image.shape[-1] == 1:
+            image = np.tile(image, (1, 1, 3))
+        else:
+            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
         cv2.imwrite(image_fname, image)
 
 
@@ -139,7 +84,10 @@ def merge_hparams(hparams0, hparams1):
 
 def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False,
                                  subtasks=None):
+    sequence_length = model_hparams.sequence_length
     context_frames = model_hparams.context_frames
+    future_length = sequence_length - context_frames
+
     context_images = results['images'][:, :context_frames]
 
     if 'eval_diversity' in results:
@@ -159,6 +107,8 @@ def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_
         for metric_name in metric_names:
             subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
             gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images'))
+            # only keep the future frames
+            gen_images = gen_images[:, -future_length:]
             metric = results['eval_%s/%s' % (metric_name, subtask)]
             save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
                          metric, sample_start_ind=sample_start_ind)
diff --git a/scripts/evaluate_svg.sh b/scripts/evaluate_svg.sh
new file mode 100644
index 0000000000000000000000000000000000000000..212c4ba239ecdbd15c70c05b9336c32175dc8c5c
--- /dev/null
+++ b/scripts/evaluate_svg.sh
@@ -0,0 +1 @@
+#!/usr/bin/env bash
\ No newline at end of file
diff --git a/scripts/generate.py b/scripts/generate.py
index 55eab9d221553d65190f7ba25d344b86171c6fa2..534eb5fa902341e3ca65a1f733f2c5f841c84210 100644
--- a/scripts/generate.py
+++ b/scripts/generate.py
@@ -115,6 +115,10 @@ def main():
         hparams_dict=hparams_dict,
         hparams=args.model_hparams)
 
+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames
+
     if args.num_samples:
         if args.num_samples > dataset.num_examples_per_epoch():
             raise ValueError('num_samples cannot be larger than the dataset')
@@ -159,6 +163,8 @@ def main():
         feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
         for stochastic_sample_ind in range(args.num_stochastic_samples):
            gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
+            # only keep the future frames
+            gen_images = gen_images[:, -future_length:]
             for i, gen_images_ in enumerate(gen_images):
                 gen_images_ = (gen_images_ * 255.0).astype(np.uint8)