Commit 0fcdd9d0 authored by Alex Lee

Fix evaluate and generate scripts to only save future frames (instead of all the generated frames). Remove heatmap images and handle images with one channel.
parent 8299e10c
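The fix rests on the split between context frames (the frames the model is conditioned on) and the frames it actually predicts: only the last future_length = sequence_length - context_frames frames of each generated sequence should be written out. Below is a minimal sketch of that slicing on a dummy batch; the variable names mirror the diff that follows, but the shapes and values are illustrative only, not taken from the repository.

import numpy as np

# Illustrative hyperparameters; in the scripts these come from the model hparams.
sequence_length = 12
context_frames = 2
future_length = sequence_length - context_frames  # 10 predicted frames

# Dummy generated sequences: (batch, time, height, width, channels).
gen_images = np.zeros((8, sequence_length, 64, 64, 3))

# Keep only the future (predicted) frames, dropping the context frames.
gen_images = gen_images[:, -future_length:]
print(gen_images.shape)  # (8, 10, 64, 64, 3)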
@@ -16,72 +16,17 @@ import tensorflow as tf
 from video_prediction import datasets, models
-def compute_expectation_np(pix_distrib):
-    assert pix_distrib.shape[-1] == 1
-    pix_distrib = pix_distrib / np.sum(pix_distrib, axis=(-3, -2), keepdims=True)
-    height, width = pix_distrib.shape[-3:-1]
-    xv, yv = np.meshgrid(np.arange(width), np.arange(height))
-    return np.stack([np.sum(yv[:, :, None] * pix_distrib, axis=(-3, -2, -1)),
-                     np.sum(xv[:, :, None] * pix_distrib, axis=(-3, -2, -1))], axis=-1)
-def as_heatmap(image, normalize=True):
-    import matplotlib.pyplot as plt
-    image = np.squeeze(image, axis=-1)
-    if normalize:
-        image = image / np.max(image, axis=(-2, -1), keepdims=True)
-    cmap = plt.get_cmap('viridis')
-    heatmap = cmap(image)[..., :3]
-    return heatmap
-def rgb2gray(rgb):
-    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
-def resize_and_draw_circle(image, size, center, radius, dpi=128.0, **kwargs):
-    import matplotlib.pyplot as plt
-    from matplotlib.patches import Circle
-    import io
-    height, width = size
-    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
-    ax = fig.add_axes([0, 0, 1, 1])
-    ax.imshow(image, interpolation='none')
-    circle = Circle(center[::-1], radius=radius, **kwargs)
-    ax.add_patch(circle)
-    ax.axis("off")
-    fig.canvas.draw()
-    trans = ax.figure.dpi_scale_trans.inverted()
-    bbox = ax.bbox.transformed(trans)
-    buff = io.BytesIO()
-    plt.savefig(buff, format="png", dpi=ax.figure.dpi, bbox_inches=bbox)
-    buff.seek(0)
-    image = plt.imread(buff)[..., :3]
-    plt.close(fig)
-    return image
-def save_image_sequence(prefix_fname, images, overlaid_images=None, centers=None,
-                        radius=5, alpha=0.8, time_start_ind=0):
+def save_image_sequence(prefix_fname, images, time_start_ind=0):
     import cv2
     head, tail = os.path.split(prefix_fname)
     if head and not os.path.exists(head):
         os.makedirs(head)
-    if images.shape[-1] == 1:
-        images = as_heatmap(images)
-    if overlaid_images is not None:
-        assert images.shape[-1] == 3
-        assert overlaid_images.shape[-1] == 1
-        gray_images = rgb2gray(images)
-        overlaid_images = as_heatmap(overlaid_images)
-        images = (1 - alpha) * gray_images[..., None] + alpha * overlaid_images
     for t, image in enumerate(images):
         image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t)
-        if centers is not None:
-            scale = np.max(np.array([256, 256]) / np.array(image.shape[:2]))
-            image = resize_and_draw_circle(image, np.array(image.shape[:2]) * scale, centers[t], radius,
-                                           edgecolor='r', fill=False, linestyle='--', linewidth=2)
         image = (image * 255.0).astype(np.uint8)
-        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+        if image.shape[-1] == 1:
+            image = np.tile(image, (1, 1, 3))
+        else:
+            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
         cv2.imwrite(image_fname, image)
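With the heatmap helpers gone, the rewritten save_image_sequence has to handle single-channel inputs itself: a grayscale (H, W, 1) frame is tiled to three channels before being written, while RGB frames are converted to BGR for OpenCV. The following is a minimal standalone sketch of that behavior on dummy arrays; write_frame is a hypothetical helper name, not part of the repository.

import cv2
import numpy as np

def write_frame(image_fname, image):
    # image is float in [0, 1], shaped (H, W, 1) or (H, W, 3).
    image = (image * 255.0).astype(np.uint8)
    if image.shape[-1] == 1:
        # Grayscale: replicate the single channel so imwrite receives a 3-channel image.
        image = np.tile(image, (1, 1, 3))
    else:
        # RGB: OpenCV expects BGR channel order.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(image_fname, image)

write_frame('gray_00.png', np.random.rand(64, 64, 1))
write_frame('rgb_00.png', np.random.rand(64, 64, 3))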
@@ -139,7 +84,10 @@ def merge_hparams(hparams0, hparams1):
 def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None):
     sequence_length = model_hparams.sequence_length
     context_frames = model_hparams.context_frames
+    future_length = sequence_length - context_frames
     context_images = results['images'][:, :context_frames]
     if 'eval_diversity' in results:
@@ -159,6 +107,8 @@ def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_
         for metric_name in metric_names:
             subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
             gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images'))
+            # only keep the future frames
+            gen_images = gen_images[:, -future_length:]
             metric = results['eval_%s/%s' % (metric_name, subtask)]
             save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
                          metric, sample_start_ind=sample_start_ind)
#!/usr/bin/env bash
\ No newline at end of file
@@ -115,6 +115,10 @@ def main():
         hparams_dict=hparams_dict,
         hparams=args.model_hparams)

+    sequence_length = model.hparams.sequence_length
+    context_frames = model.hparams.context_frames
+    future_length = sequence_length - context_frames
+
     if args.num_samples:
         if args.num_samples > dataset.num_examples_per_epoch():
             raise ValueError('num_samples cannot be larger than the dataset')
@@ -159,6 +163,8 @@ def main():
         feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
         for stochastic_sample_ind in range(args.num_stochastic_samples):
             gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
+            # only keep the future frames
+            gen_images = gen_images[:, -future_length:]
             for i, gen_images_ in enumerate(gen_images):
                 gen_images_ = (gen_images_ * 255.0).astype(np.uint8)