Skip to content
Snippets Groups Projects
Commit 0fcdd9d0 authored by Alex Lee's avatar Alex Lee
Browse files

Fix evaluate and generate scripts to only save future frames (instead of all...

Fix evaluate and generate scripts to only save future frames (instead of all the generated frames). Remove heatmap images and handle images with one channel.
parent 8299e10c
No related branches found
No related tags found
No related merge requests found
...@@ -16,72 +16,17 @@ import tensorflow as tf ...@@ -16,72 +16,17 @@ import tensorflow as tf
from video_prediction import datasets, models from video_prediction import datasets, models
def compute_expectation_np(pix_distrib): def save_image_sequence(prefix_fname, images, time_start_ind=0):
assert pix_distrib.shape[-1] == 1
pix_distrib = pix_distrib / np.sum(pix_distrib, axis=(-3, -2), keepdims=True)
height, width = pix_distrib.shape[-3:-1]
xv, yv = np.meshgrid(np.arange(width), np.arange(height))
return np.stack([np.sum(yv[:, :, None] * pix_distrib, axis=(-3, -2, -1)),
np.sum(xv[:, :, None] * pix_distrib, axis=(-3, -2, -1))], axis=-1)
def as_heatmap(image, normalize=True):
import matplotlib.pyplot as plt
image = np.squeeze(image, axis=-1)
if normalize:
image = image / np.max(image, axis=(-2, -1), keepdims=True)
cmap = plt.get_cmap('viridis')
heatmap = cmap(image)[..., :3]
return heatmap
def rgb2gray(rgb):
return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
def resize_and_draw_circle(image, size, center, radius, dpi=128.0, **kwargs):
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import io
height, width = size
fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
ax = fig.add_axes([0, 0, 1, 1])
ax.imshow(image, interpolation='none')
circle = Circle(center[::-1], radius=radius, **kwargs)
ax.add_patch(circle)
ax.axis("off")
fig.canvas.draw()
trans = ax.figure.dpi_scale_trans.inverted()
bbox = ax.bbox.transformed(trans)
buff = io.BytesIO()
plt.savefig(buff, format="png", dpi=ax.figure.dpi, bbox_inches=bbox)
buff.seek(0)
image = plt.imread(buff)[..., :3]
plt.close(fig)
return image
def save_image_sequence(prefix_fname, images, overlaid_images=None, centers=None,
radius=5, alpha=0.8, time_start_ind=0):
import cv2 import cv2
head, tail = os.path.split(prefix_fname) head, tail = os.path.split(prefix_fname)
if head and not os.path.exists(head): if head and not os.path.exists(head):
os.makedirs(head) os.makedirs(head)
if images.shape[-1] == 1:
images = as_heatmap(images)
if overlaid_images is not None:
assert images.shape[-1] == 3
assert overlaid_images.shape[-1] == 1
gray_images = rgb2gray(images)
overlaid_images = as_heatmap(overlaid_images)
images = (1 - alpha) * gray_images[..., None] + alpha * overlaid_images
for t, image in enumerate(images): for t, image in enumerate(images):
image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t) image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t)
if centers is not None:
scale = np.max(np.array([256, 256]) / np.array(image.shape[:2]))
image = resize_and_draw_circle(image, np.array(image.shape[:2]) * scale, centers[t], radius,
edgecolor='r', fill=False, linestyle='--', linewidth=2)
image = (image * 255.0).astype(np.uint8) image = (image * 255.0).astype(np.uint8)
if image.shape[-1] == 1:
image = np.tile(image, (1, 1, 3))
else:
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imwrite(image_fname, image) cv2.imwrite(image_fname, image)
...@@ -139,7 +84,10 @@ def merge_hparams(hparams0, hparams1): ...@@ -139,7 +84,10 @@ def merge_hparams(hparams0, hparams1):
def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None): def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None):
sequence_length = model_hparams.sequence_length
context_frames = model_hparams.context_frames context_frames = model_hparams.context_frames
future_length = sequence_length - context_frames
context_images = results['images'][:, :context_frames] context_images = results['images'][:, :context_frames]
if 'eval_diversity' in results: if 'eval_diversity' in results:
...@@ -159,6 +107,8 @@ def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ ...@@ -159,6 +107,8 @@ def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_
for metric_name in metric_names: for metric_name in metric_names:
subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask) subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images')) gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images'))
# only keep the future frames
gen_images = gen_images[:, -future_length:]
metric = results['eval_%s/%s' % (metric_name, subtask)] metric = results['eval_%s/%s' % (metric_name, subtask)]
save_metrics(os.path.join(subtask_dir, 'metrics', metric_name), save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
metric, sample_start_ind=sample_start_ind) metric, sample_start_ind=sample_start_ind)
......
#!/usr/bin/env bash
\ No newline at end of file
...@@ -115,6 +115,10 @@ def main(): ...@@ -115,6 +115,10 @@ def main():
hparams_dict=hparams_dict, hparams_dict=hparams_dict,
hparams=args.model_hparams) hparams=args.model_hparams)
sequence_length = model.hparams.sequence_length
context_frames = model.hparams.context_frames
future_length = sequence_length - context_frames
if args.num_samples: if args.num_samples:
if args.num_samples > dataset.num_examples_per_epoch(): if args.num_samples > dataset.num_examples_per_epoch():
raise ValueError('num_samples cannot be larger than the dataset') raise ValueError('num_samples cannot be larger than the dataset')
...@@ -159,6 +163,8 @@ def main(): ...@@ -159,6 +163,8 @@ def main():
feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
for stochastic_sample_ind in range(args.num_stochastic_samples): for stochastic_sample_ind in range(args.num_stochastic_samples):
gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict) gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
# only keep the future frames
gen_images = gen_images[:, -future_length:]
for i, gen_images_ in enumerate(gen_images): for i, gen_images_ in enumerate(gen_images):
gen_images_ = (gen_images_ * 255.0).astype(np.uint8) gen_images_ = (gen_images_ * 255.0).astype(np.uint8)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment