diff --git a/video_prediction/utils/tf_utils.py b/video_prediction/utils/tf_utils.py index 979d95bd8a0c57aa49aaeaa61aa915efbb0cdff9..04cb23bfc3b15cf2a74a1885a869079222ba403b 100644 --- a/video_prediction/utils/tf_utils.py +++ b/video_prediction/utils/tf_utils.py @@ -255,6 +255,8 @@ def add_summaries(outputs, collections=None): image_outputs = OrderedDict() gif_outputs = OrderedDict() for name, output in outputs.items(): + if not isinstance(output, tf.Tensor): + continue if output.shape.ndims == 0: scalar_outputs[name] = output elif output.shape.ndims == 4: @@ -529,7 +531,7 @@ def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=Fal checkpoint = tf.train.latest_checkpoint(checkpoint) checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint) checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys() - restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name: name.split(':')[0]) + restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0]) if not var_list: var_list = tf.global_variables() restore_vars = {restore_to_checkpoint_mapping(var.name, checkpoint_var_names): var for var in var_list}