From 39570d54ad420a1db82225c9a496513e574b8be2 Mon Sep 17 00:00:00 2001
From: Alex Lee <alexleegk@gmail.com>
Date: Wed, 23 Jan 2019 11:51:14 -0800
Subject: [PATCH] Minor tf_utils.

---
 video_prediction/utils/tf_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/video_prediction/utils/tf_utils.py b/video_prediction/utils/tf_utils.py
index 979d95bd..04cb23bf 100644
--- a/video_prediction/utils/tf_utils.py
+++ b/video_prediction/utils/tf_utils.py
@@ -255,6 +255,8 @@ def add_summaries(outputs, collections=None):
     image_outputs = OrderedDict()
     gif_outputs = OrderedDict()
     for name, output in outputs.items():
+        if not isinstance(output, tf.Tensor):
+            continue
         if output.shape.ndims == 0:
             scalar_outputs[name] = output
         elif output.shape.ndims == 4:
@@ -529,7 +531,7 @@ def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=Fal
         checkpoint = tf.train.latest_checkpoint(checkpoint)
     checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint)
     checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys()
-    restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name: name.split(':')[0])
+    restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0])
     if not var_list:
         var_list = tf.global_variables()
     restore_vars = {restore_to_checkpoint_mapping(var.name, checkpoint_var_names): var for var in var_list}
-- 
GitLab