diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 8de2a58b006c1c71d1f9e58d0eaba6ce4b1a23d3..24d84fb6105a6b4a3485680991d9fa46b3b1cb9d 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import re
 import argparse
 import csv
 import errno
@@ -11,10 +12,8 @@ import random
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.python.util import nest
 
-from video_prediction import datasets, models, metrics
-from video_prediction.policies.servo_policy import ServoPolicy
+from video_prediction import datasets, models
 
 
 def compute_expectation_np(pix_distrib):
@@ -142,23 +141,25 @@ def merge_hparams(hparams0, hparams1):
 def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None):
     context_frames = model_hparams.context_frames
     context_images = results['images'][:, :context_frames]
-    images = results['eval_images']
-    metric_names = ['psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'ssim_mcnet', 'vgg_csim']
-    metric_fns = [metrics.peak_signal_to_noise_ratio_np,
-                  metrics.structural_similarity_np,
-                  metrics.structural_similarity_scikit_np,
-                  metrics.structural_similarity_finn_np,
-                  metrics.structural_similarity_mcnet_np,
-                  None]
+
+    if 'eval_diversity' in results:
+        metric = results['eval_diversity']
+        metric_name = 'diversity'
+        subtask_dir = task_dir + '_%s' % metric_name
+        save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
+                     metric, sample_start_ind=sample_start_ind)
+
     subtasks = subtasks or ['max']
-    for metric_name, metric_fn in zip(metric_names, metric_fns):
-        for subtask in subtasks:
+    for subtask in subtasks:
+        metric_names = []
+        for k in results.keys():
+            if re.match(r'eval_(\w+)/%s' % subtask, k) and not re.match(r'eval_gen_images_(\w+)/%s' % subtask, k):
+                m = re.match(r'eval_(\w+)/%s' % subtask, k)
+                metric_names.append(m.group(1))
+        for metric_name in metric_names:
             subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask)
             gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images'))
-            if metric_fn is not None:  # recompute using numpy implementation
-                metric = metric_fn(images, gen_images, keep_axis=(0, 1))
-            else:
-                metric = results['eval_%s/%s' % (metric_name, subtask)]
+            metric = results['eval_%s/%s' % (metric_name, subtask)]
             save_metrics(os.path.join(subtask_dir, 'metrics', metric_name),
                          metric, sample_start_ind=sample_start_ind)
     if only_metrics:
@@ -170,133 +171,28 @@ def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_
                          gen_images, sample_start_ind=sample_start_ind)
 
 
-def save_prediction_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False):
-    context_frames = model_hparams.context_frames
-    sequence_length = model_hparams.sequence_length
-    context_images, images = np.split(results['images'], [context_frames], axis=1)
-    gen_images = results['gen_images'][:, context_frames - sequence_length:]
-    psnr = metrics.peak_signal_to_noise_ratio_np(images, gen_images, keep_axis=(0, 1))
-    mse = metrics.mean_squared_error_np(images, gen_images, keep_axis=(0, 1))
-    ssim = metrics.structural_similarity_np(images, gen_images, keep_axis=(0, 1))
-    save_metrics(os.path.join(task_dir, 'metrics', 'psnr'),
-                 psnr, sample_start_ind=sample_start_ind)
-    save_metrics(os.path.join(task_dir, 'metrics', 'mse'),
-                 mse, sample_start_ind=sample_start_ind)
-    save_metrics(os.path.join(task_dir, 'metrics', 'ssim'),
-                 ssim, sample_start_ind=sample_start_ind)
-    if only_metrics:
-        return
-
-    save_image_sequences(os.path.join(task_dir, 'inputs', 'context_image'),
-                         context_images, sample_start_ind=sample_start_ind)
-    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image'),
-                         gen_images, sample_start_ind=sample_start_ind)
-
-
-def save_motion_results(task_dir, results, model_hparams, draw_center=False,
-                        sample_start_ind=0, only_metrics=False):
-    context_frames = model_hparams.context_frames
-    sequence_length = model_hparams.sequence_length
-    pix_distribs = results['pix_distribs'][:, context_frames:]
-    gen_pix_distribs = results['gen_pix_distribs'][:, context_frames - sequence_length:]
-    pix_dist = metrics.expected_pixel_distance_np(pix_distribs, gen_pix_distribs, keep_axis=(0, 1))
-    save_metrics(os.path.join(task_dir, 'metrics', 'pix_dist'),
-                 pix_dist, sample_start_ind=sample_start_ind)
-    if only_metrics:
-        return
-
-    context_images, images = np.split(results['images'], [context_frames], axis=1)
-    gen_images = results['gen_images'][:, context_frames - sequence_length:]
-    initial_pix_distrib = results['pix_distribs'][:, 0:1]
-    num_motions = pix_distribs.shape[-1]
-    for i in range(num_motions):
-        output_name_posfix = '%d' % i if num_motions > 1 else ''
-        centers = compute_expectation_np(initial_pix_distrib[..., i:i + 1]) if draw_center else None
-        save_image_sequences(os.path.join(task_dir, 'inputs', 'pix_distrib%s' % output_name_posfix),
-                             context_images[:, 0:1], initial_pix_distrib[..., i:i + 1], centers, sample_start_ind=sample_start_ind)
-        centers = compute_expectation_np(gen_pix_distribs[..., i:i + 1]) if draw_center else None
-        save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_pix_distrib%s' % output_name_posfix),
-                             gen_images, gen_pix_distribs[..., i:i + 1], centers, sample_start_ind=sample_start_ind)
-
-
-def save_servo_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False):
-    context_frames = model_hparams.context_frames
-    sequence_length = model_hparams.sequence_length
-    context_images, images = np.split(results['images'], [context_frames], axis=1)
-    gen_images = results['gen_images'][:, context_frames - sequence_length:]
-    goal_image = results['goal_image']
-    # TODO: should exclude "context" actions assuming that they are passed in to the network
-    actions = results['actions']
-    gen_actions = results['gen_actions']
-    goal_image_mse = metrics.mean_squared_error_np(goal_image, gen_images[:, -1], keep_axis=0)
-    action_mse = metrics.mean_squared_error_np(actions, gen_actions, keep_axis=(0, 1))
-    save_metrics(os.path.join(task_dir, 'metrics', 'goal_image_mse'),
-                 goal_image_mse[:, None], sample_start_ind=sample_start_ind)
-    save_metrics(os.path.join(task_dir, 'metrics', 'action_mse'),
-                 action_mse, sample_start_ind=sample_start_ind)
-    if only_metrics:
-        return
-
-    save_image_sequences(os.path.join(task_dir, 'inputs', 'context_image'),
-                         context_images, sample_start_ind=sample_start_ind)
-    save_image_sequences(os.path.join(task_dir, 'inputs', 'goal_image'),
-                         goal_image[:, None], sample_start_ind=sample_start_ind)
-    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image'),
-                         gen_images, sample_start_ind=sample_start_ind)
-    gen_image_goal_diffs = np.abs(gen_images - goal_image[:, None])
-    save_image_sequences(os.path.join(task_dir, 'outputs', 'gen_image_goal_diff'),
-                         gen_image_goal_diffs, sample_start_ind=sample_start_ind)
-
-
 def main():
     """
     results_dir
     ├── output_dir                                  # condition / method
-    │   ├── prediction                              # task
+    │   ├── prediction_eval_lpips_max               # task: best sample in terms of LPIPS similarity
     │   │   ├── inputs
     │   │   │   ├── context_image_00000_00.png      # indexed by sample index and time step
     │   │   │   └── ...
     │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.png          # predicted images (only the ones in the loss)
+    │   │   │   ├── gen_image_00000_00.png          # predicted images (only the future ones)
     │   │   │   └── ...
     │   │   └── metrics
-    │   │       ├── psnr.csv
-    │   │       ├── mse.csv
-    │   │       └── ssim.csv
-    │   ├── prediction_eval_vgg_csim_max            # task: best sample in terms of VGG cosine similarity
+    │   │       └── lpips.csv
+    │   ├── prediction_eval_ssim_max                # task: best sample in terms of SSIM
     │   │   ├── inputs
    │   │   │   ├── context_image_00000_00.png      # indexed by sample index and time step
     │   │   │   └── ...
     │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.png          # predicted images (only the ones in the loss)
-    │   │   │   └── ...
-    │   │   └── metrics
-    │   │       └── vgg_csim.csv
-    │   ├── servo
-    │   │   ├── inputs
-    │   │   │   ├── context_image_00000_00.png
-    │   │   │   ├── ...
-    │   │   │   ├── goal_image_00000_00.png         # only one goal image per sample
-    │   │   │   └── ...
-    │   │   ├── outputs
-    │   │   │   ├── gen_image_00000_00.png
-    │   │   │   ├── ...
-    │   │   │   ├── gen_image_goal_diff_00000_00.png
-    │   │   │   └── ...
-    │   │   └── metrics
-    │   │       ├── action_mse.csv
-    │   │       └── goal_image_mse.csv
-    │   ├── motion
-    │   │   ├── inputs
-    │   │   │   ├── pix_distrib_00000_00.png
-    │   │   │   └── ...
-    │   │   ├── outputs
-    │   │   │   ├── gen_pix_distrib_00000_00.png
-    │   │   │   ├── ...
-    │   │   │   ├── gen_pix_distrib_overlaid_00000_00.png
+    │   │   │   ├── gen_image_00000_00.png          # predicted images (only the future ones)
     │   │   │   └── ...
     │   │   └── metrics
-    │   │       └── pix_dist.csv
+    │   │       └── ssim.csv
     │   └── ...
     └── ...
     """
@@ -320,8 +216,7 @@ def main():
     parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)")
     parser.add_argument("--num_epochs", type=int, default=1)
-    parser.add_argument("--tasks", type=str, nargs='+', help='tasks to evaluate (e.g. prediction, prediction_eval, servo, motion)')
-    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'min'], help='subtasks to evaluate (e.g. max, avg, min). only applicable to prediction_eval')
+    parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'], help='subtasks to evaluate (e.g. max, avg, min)')
     parser.add_argument("--only_metrics", action='store_true')
     parser.add_argument("--num_stochastic_samples", type=int, default=100)
@@ -343,10 +238,10 @@ def main():
     model_hparams_dict = {}
     if args.checkpoint:
         checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         if not os.path.isdir(args.checkpoint):
             checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         with open(os.path.join(checkpoint_dir, "options.json")) as f:
             print("loading options from checkpoint %s" % args.checkpoint)
             options = json.loads(f.read())
@@ -360,7 +255,6 @@ def main():
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                 model_hparams_dict = json.loads(f.read())
-                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
     args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1])
@@ -377,21 +271,26 @@ def main():
     print('------------------------------------- End --------------------------------------')
 
     VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
-                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
-
-    def override_hparams_dict(dataset):
-        hparams_dict = dict(model_hparams_dict)
-        hparams_dict['context_frames'] = dataset.hparams.context_frames
-        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
-        hparams_dict['repeat'] = dataset.hparams.time_shift
-        return hparams_dict
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
 
     VideoPredictionModel = models.get_model_class(args.model)
-    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams,
-                                 eval_num_samples=args.num_stochastic_samples, eval_parallel_iterations=args.eval_parallel_iterations)
-    context_frames = model.hparams.context_frames
-    sequence_length = model.hparams.sequence_length
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams,
+        eval_num_samples=args.num_stochastic_samples,
+        eval_parallel_iterations=args.eval_parallel_iterations)
 
     if args.num_samples:
         if args.num_samples > dataset.num_examples_per_epoch():
@@ -400,44 +299,12 @@ def main():
     else:
         num_examples_per_epoch = dataset.num_examples_per_epoch()
     if num_examples_per_epoch % args.batch_size != 0:
-        raise ValueError('batch_size should evenly divide the dataset')
-
-    inputs, target = dataset.make_batch(args.batch_size)
-    if not isinstance(model, models.GroundTruthVideoPredictionModel):
-        # remove ground truth data past context_frames to prevent accidentally using it
-        for k, v in inputs.items():
-            if k != 'actions':
-                inputs[k] = v[:, :context_frames]
+        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+    inputs = dataset.make_batch(args.batch_size)
     input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
-    target_ph = tf.placeholder(target.dtype, target.shape, 'targets_ph')
-
-    with tf.variable_scope(''):
-        model.build_graph(input_phs, target_ph)
-
-    tasks = args.tasks
-    if tasks is None:
-        tasks = ['prediction_eval']
-        if 'pix_distribs' in inputs:
-            tasks.append('motion')
-
-    if 'servo' in tasks:
-        servo_model = VideoPredictionModel(mode='test', hparams_dict=model_hparams_dict, hparams=args.model_hparams)
-        cem_batch_size = 200
-        plan_horizon = sequence_length - 1
-        image_shape = inputs['images'].shape.as_list()[2:]
-        state_shape = inputs['states'].shape.as_list()[2:]
-        action_shape = inputs['actions'].shape.as_list()[2:]
-        servo_input_phs = {
-            'images': tf.placeholder(tf.float32, shape=[cem_batch_size, context_frames] + image_shape),
-            'states': tf.placeholder(tf.float32, shape=[cem_batch_size, 1] + state_shape),
-            'actions': tf.placeholder(tf.float32, shape=[cem_batch_size, plan_horizon] + action_shape),
-        }
-        if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
-            images_shape = inputs['images'].shape.as_list()[1:]
-            servo_input_phs['images'] = tf.placeholder(tf.float32, shape=[cem_batch_size] + images_shape)
-        with tf.variable_scope('', reuse=True):
-            servo_model.build_graph(servo_input_phs)
+    model.build_graph(input_phs)
 
     output_dir = args.output_dir
     if not os.path.exists(output_dir):
@@ -452,156 +319,37 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
     sess = tf.Session(config=config)
+    sess.graph.as_default()
 
     model.restore(sess, args.checkpoint)
 
-    if 'servo' in tasks:
-        servo_policy = ServoPolicy(servo_model, sess)
-
     sample_ind = 0
     while True:
         if args.num_samples and sample_ind >= args.num_samples:
             break
         try:
-            input_results, target_result = sess.run([inputs, target])
+            input_results = sess.run(inputs)
        except tf.errors.OutOfRangeError:
             break
         print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size))
 
-        if 'prediction_eval' in tasks:
-            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-            feed_dict.update({target_ph: target_result})
-            # compute "best" metrics using the computation graph (if available) or explicitly with python logic
-            if model.eval_outputs and model.eval_metrics:
-                fetches = {'images': model.inputs['images']}
-                fetches.update(model.eval_outputs.items())
-                fetches.update(model.eval_metrics.items())
-                results = sess.run(fetches, feed_dict=feed_dict)
-            else:
-                metric_names = ['psnr', 'ssim', 'ssim_scikit', 'ssim_finn', 'ssim_mcnet', 'vgg_csim']
-                metric_fns = [metrics.peak_signal_to_noise_ratio_np,
-                              metrics.structural_similarity_np,
-                              metrics.structural_similarity_scikit_np,
-                              metrics.structural_similarity_finn_np,
-                              metrics.structural_similarity_mcnet_np,
-                              metrics.vgg_cosine_similarity_np]
-
-                all_gen_images = []
-                all_metrics = [np.empty((args.num_stochastic_samples, args.batch_size, sequence_length - context_frames)) for _ in metric_names]
-                for s in range(args.num_stochastic_samples):
-                    gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict)
-                    all_gen_images.append(gen_images)
-                    for metric_name, metric_fn, all_metric in zip(metric_names, metric_fns, all_metrics):
-                        metric = metric_fn(gen_images, target_result, keep_axis=(0, 1))
-                        all_metric[s] = metric
-
-                results = {'images': input_results['images'],
-                           'eval_images': target_result}
-                for metric_name, all_metric in zip(metric_names, all_metrics):
-                    for subtask in args.eval_substasks:
-                        results['eval_gen_images_%s/%s' % (metric_name, subtask)] = np.empty_like(all_gen_images[0])
-                        results['eval_%s/%s' % (metric_name, subtask)] = np.empty_like(all_metric[0])
-
-                for i in range(args.batch_size):
-                    for metric_name, all_metric in zip(metric_names, all_metrics):
-                        ordered = np.argsort(np.mean(all_metric, axis=-1)[:, i])  # mean over time and sort over samples
-                        for subtask in args.eval_substasks:
-                            if subtask == 'max':
-                                sidx = ordered[-1]
-                            elif subtask == 'min':
-                                sidx = ordered[0]
-                            else:
-                                raise NotImplementedError
-                            results['eval_gen_images_%s/%s' % (metric_name, subtask)][i] = all_gen_images[sidx][i]
-                            results['eval_%s/%s' % (metric_name, subtask)][i] = all_metric[sidx][i]
-            save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
-                                         results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)
-
-        if 'prediction' in tasks or 'motion' in tasks:  # do these together
-            feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
-            fetches = {'images': model.inputs['images'],
-                       'gen_images': model.outputs['gen_images']}
-            if 'motion' in tasks:
-                fetches.update({'pix_distribs': model.inputs['pix_distribs'],
-                                'gen_pix_distribs': model.outputs['gen_pix_distribs']})
-
-            if args.num_stochastic_samples:
-                all_results = [sess.run(fetches, feed_dict=feed_dict) for _ in range(args.num_stochastic_samples)]
-                all_results = nest.map_structure(lambda *x: np.stack(x), *all_results)
-                all_context_images, all_images = np.split(all_results['images'], [context_frames], axis=2)
-                all_gen_images = all_results['gen_images'][:, :, context_frames - sequence_length:]
-                all_mse = metrics.mean_squared_error_np(all_images, all_gen_images, keep_axis=(0, 1))
-                all_mse_argsort = np.argsort(all_mse, axis=0)
-
-                for subtask, argsort_ind in zip(['_best', '_median', '_worst'],
-                                                [0, args.num_stochastic_samples // 2, -1]):
-                    all_mse_inds = all_mse_argsort[argsort_ind]
-                    gather = lambda x: np.array([x[ind, sample_ind] for sample_ind, ind in enumerate(all_mse_inds)])
-                    results = nest.map_structure(gather, all_results)
-                    if 'prediction' in tasks:
-                        save_prediction_results(os.path.join(output_dir, 'prediction' + subtask),
-                                                results, model.hparams, sample_ind, args.only_metrics)
-                    if 'motion' in tasks:
-                        draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
-                        save_motion_results(os.path.join(output_dir, 'motion' + subtask),
-                                            results, model.hparams, draw_center, sample_ind, args.only_metrics)
-            else:
-                results = sess.run(fetches, feed_dict=feed_dict)
-                if 'prediction' in tasks:
-                    save_prediction_results(os.path.join(output_dir, 'prediction'),
-                                            results, model.hparams, sample_ind, args.only_metrics)
-                if 'motion' in tasks:
-                    draw_center = isinstance(model, models.NonTrainableVideoPredictionModel)
-                    save_motion_results(os.path.join(output_dir, 'motion'),
-                                        results, model.hparams, draw_center, sample_ind, args.only_metrics)
-
-        if 'servo' in tasks:
-            images = input_results['images']
-            states = input_results['states']
-            gen_actions = []
-            gen_images = []
-            for images_, states_ in zip(images, states):
-                obs = {'context_images': images_[:context_frames],
-                       'context_state': states_[0],
-                       'goal_image': images_[-1]}
-                if isinstance(servo_model, models.GroundTruthVideoPredictionModel):
-                    obs['context_images'] = images_
-                gen_actions_, gen_images_ = servo_policy.act(obs, servo_model.outputs['gen_images'])
-                gen_actions.append(gen_actions_)
-                gen_images.append(gen_images_)
-            gen_actions = np.stack(gen_actions)
-            gen_images = np.stack(gen_images)
-            results = {'images': input_results['images'],
-                       'actions': input_results['actions'],
-                       'goal_image': input_results['images'][:, -1],
-                       'gen_actions': gen_actions,
-                       'gen_images': gen_images}
-            save_servo_results(os.path.join(output_dir, 'servo'),
-                               results, servo_model.hparams, sample_ind, args.only_metrics)
-
+        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        # compute "best" metrics using the computation graph
+        fetches = {'images': model.inputs['images']}
+        fetches.update(model.eval_outputs.items())
+        fetches.update(model.eval_metrics.items())
+        results = sess.run(fetches, feed_dict=feed_dict)
+        save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'),
+                                     results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks)
         sample_ind += args.batch_size
 
     metric_fnames = []
-    if 'prediction_eval' in tasks:
-        metric_names = ['psnr', 'ssim', 'ssim_finn', 'vgg_csim']
-        subtasks = ['max']
-        for metric_name in metric_names:
-            for subtask in subtasks:
-                metric_fnames.append(
-                    os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name))
-    if 'prediction' in tasks:
-        subtask = '_best' if args.num_stochastic_samples else ''
-        metric_fnames.extend([
-            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'psnr'),
-            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'mse'),
-            os.path.join(output_dir, 'prediction' + subtask, 'metrics', 'ssim'),
-        ])
-    if 'motion' in tasks:
-        subtask = '_best' if args.num_stochastic_samples else ''
-        metric_fnames.append(os.path.join(output_dir, 'motion' + subtask, 'metrics', 'pix_dist'))
-    if 'servo' in tasks:
-        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'goal_image_mse'))
-        metric_fnames.append(os.path.join(output_dir, 'servo', 'metrics', 'action_mse'))
+    metric_names = ['psnr', 'ssim', 'lpips']
+    subtasks = ['max']
+    for metric_name in metric_names:
+        for subtask in subtasks:
+            metric_fnames.append(
+                os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name))
 
     for metric_fname in metric_fnames:
         task_name, _, metric_name = metric_fname.split('/')[-3:]
diff --git a/scripts/generate.py b/scripts/generate.py
index bfc5e89eb8651895aacca3d96be8630ab4ee1517..55eab9d221553d65190f7ba25d344b86171c6fa2 100644
--- a/scripts/generate.py
+++ b/scripts/generate.py
@@ -61,10 +61,10 @@ def main():
     model_hparams_dict = {}
     if args.checkpoint:
         checkpoint_dir = os.path.normpath(args.checkpoint)
-        if not os.path.exists(checkpoint_dir):
-            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         if not os.path.isdir(args.checkpoint):
             checkpoint_dir, _ = os.path.split(checkpoint_dir)
+        if not os.path.exists(checkpoint_dir):
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
         with open(os.path.join(checkpoint_dir, "options.json")) as f:
             print("loading options from checkpoint %s" % args.checkpoint)
             options = json.loads(f.read())
@@ -78,7 +78,6 @@ def main():
         try:
             with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                 model_hparams_dict = json.loads(f.read())
-                model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
         except FileNotFoundError:
             print("model_hparams.json was not loaded because it does not exist")
     args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1])
@@ -97,18 +96,24 @@ def main():
     print('------------------------------------- End --------------------------------------')
 
     VideoDataset = datasets.get_dataset_class(args.dataset)
-    dataset = VideoDataset(args.input_dir, mode=args.mode, num_epochs=args.num_epochs, seed=args.seed,
-                           hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
-
-    def override_hparams_dict(dataset):
-        hparams_dict = dict(model_hparams_dict)
-        hparams_dict['context_frames'] = dataset.hparams.context_frames
-        hparams_dict['sequence_length'] = dataset.hparams.sequence_length
-        hparams_dict['repeat'] = dataset.hparams.time_shift
-        return hparams_dict
+    dataset = VideoDataset(
+        args.input_dir,
+        mode=args.mode,
+        num_epochs=args.num_epochs,
+        seed=args.seed,
+        hparams_dict=dataset_hparams_dict,
+        hparams=args.dataset_hparams)
 
     VideoPredictionModel = models.get_model_class(args.model)
-    model = VideoPredictionModel(mode='test', hparams_dict=override_hparams_dict(dataset), hparams=args.model_hparams)
+    hparams_dict = dict(model_hparams_dict)
+    hparams_dict.update({
+        'context_frames': dataset.hparams.context_frames,
+        'sequence_length': dataset.hparams.sequence_length,
+        'repeat': dataset.hparams.time_shift,
+    })
+    model = VideoPredictionModel(
+        hparams_dict=hparams_dict,
+        hparams=args.model_hparams)
 
     if args.num_samples:
         if args.num_samples > dataset.num_examples_per_epoch():
@@ -117,17 +122,10 @@ def main():
     else:
         num_examples_per_epoch = dataset.num_examples_per_epoch()
     if num_examples_per_epoch % args.batch_size != 0:
-        raise ValueError('batch_size should evenly divide the dataset')
-
-    inputs, _ = dataset.make_batch(args.batch_size)
-    if not isinstance(model, models.GroundTruthVideoPredictionModel):
-        # remove ground truth data past context_frames to prevent accidentally using it
-        for k, v in inputs.items():
-            if k != 'actions':
-                inputs[k] = v[:, :model.hparams.context_frames]
+        raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch)
+    inputs = dataset.make_batch(args.batch_size)
     input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
-    with tf.variable_scope(''):
-        model.build_graph(input_phs)
+    model.build_graph(input_phs)
 
@@ -144,6 +142,7 @@ def main():
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
     config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
     sess = tf.Session(config=config)
+    sess.graph.as_default()
 
     model.restore(sess, args.checkpoint)
 
@@ -170,7 +169,10 @@ def main():
                 gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1)))
                 for t, gen_image in enumerate(gen_images_):
                     gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t)
-                    gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
+                    if gen_image.shape[-1] == 1:
+                        gen_image = np.tile(gen_image, (1, 1, 3))
+                    else:
+                        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
                     cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image)
 
         sample_ind += args.batch_size
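
Note on the rewritten save_prediction_eval_results: metric names are no longer hardcoded; they are discovered from the fetched result keys. A minimal standalone sketch of that discovery step, using hypothetical keys that follow the 'eval_<metric>/<subtask>' convention above:

    import re

    # hypothetical fetch results following the key convention in the patch
    results = {
        'eval_psnr/max': 'psnr values',
        'eval_ssim/max': 'ssim values',
        'eval_lpips/max': 'lpips values',
        'eval_gen_images_psnr/max': 'image tensors, not a metric',
    }

    subtask = 'max'
    metric_names = []
    for k in results.keys():
        # re.match anchors at the start of the string, so 'eval_gen_images_*'
        # keys also match the first pattern and must be excluded explicitly
        m = re.match(r'eval_(\w+)/%s' % subtask, k)
        if m and not re.match(r'eval_gen_images_(\w+)/%s' % subtask, k):
            metric_names.append(m.group(1))

    print(sorted(metric_names))  # ['lpips', 'psnr', 'ssim']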
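
For context on the deleted Python-side fallback (the graph-side eval_outputs/eval_metrics now cover it): it ranked the stochastic samples of each example by their time-averaged score and kept the extremes for the 'max' and 'min' subtasks. The same idea as a self-contained sketch with made-up shapes:

    import numpy as np

    num_stochastic_samples, batch_size, num_frames = 5, 2, 4  # illustrative only
    rng = np.random.RandomState(0)
    all_metric = rng.rand(num_stochastic_samples, batch_size, num_frames)  # e.g. per-frame PSNR

    for i in range(batch_size):
        # mean over time, then argsort over the stochastic samples (ascending)
        ordered = np.argsort(np.mean(all_metric, axis=-1)[:, i])
        best, worst = ordered[-1], ordered[0]  # 'max' and 'min' subtasks
        print('example %d: best sample %d, worst sample %d' % (i, best, worst))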
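
Both scripts reorder the checkpoint checks so that the existence test runs on the directory that is actually opened, rather than on the raw --checkpoint argument. A small sketch of the corrected order (the checkpoint prefix name is hypothetical):

    import errno
    import os
    import tempfile

    def resolve_checkpoint_dir(checkpoint):
        checkpoint_dir = os.path.normpath(checkpoint)
        if not os.path.isdir(checkpoint):
            # a file-like checkpoint path (e.g. a checkpoint prefix) is reduced
            # to its parent directory before the existence check below
            checkpoint_dir, _ = os.path.split(checkpoint_dir)
        if not os.path.exists(checkpoint_dir):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
        return checkpoint_dir

    with tempfile.TemporaryDirectory() as d:
        ckpt = os.path.join(d, 'model-300000')    # hypothetical checkpoint prefix
        assert resolve_checkpoint_dir(ckpt) == d  # the parent directory is used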
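
The simplified evaluation loop feeds the fetched numpy batch back through per-input placeholders, one placeholder per dataset tensor. A stripped-down sketch of that pattern (TF 1.x API, toy tensors standing in for dataset.make_batch):

    import tensorflow as tf  # TF 1.x, as used by these scripts

    inputs = {'images': tf.zeros([2, 12, 64, 64, 3])}  # toy stand-in for the batch
    input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}

    with tf.Session() as sess:
        input_results = sess.run(inputs)
        feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
        print({k.name: v.shape for k, v in feed_dict.items()})  # {'images_ph:0': (2, 12, 64, 64, 3)}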
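
The generate.py change guards cv2.cvtColor, which rejects single-channel input for COLOR_RGB2BGR, by replicating the channel instead. A minimal sketch of the same guard with dummy data (requires numpy and opencv-python):

    import numpy as np
    import cv2

    gen_image = (np.random.rand(64, 64, 1) * 255).astype(np.uint8)  # dummy grayscale frame
    if gen_image.shape[-1] == 1:
        # replicate the single channel so cv2.imwrite still receives an HxWx3
        # array; for grayscale data the RGB->BGR swap is a no-op anyway
        gen_image = np.tile(gen_image, (1, 1, 3))
    else:
        gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite('gen_image_00000_00_00.png', gen_image)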