Skip to content
Snippets Groups Projects
Commit 39570d54 authored by Alex Lee's avatar Alex Lee
Browse files

Minor tf_utils.

parent 1fffa00b
Branches
Tags
No related merge requests found
......@@ -255,6 +255,8 @@ def add_summaries(outputs, collections=None):
image_outputs = OrderedDict()
gif_outputs = OrderedDict()
for name, output in outputs.items():
if not isinstance(output, tf.Tensor):
continue
if output.shape.ndims == 0:
scalar_outputs[name] = output
elif output.shape.ndims == 4:
......@@ -529,7 +531,7 @@ def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=Fal
checkpoint = tf.train.latest_checkpoint(checkpoint)
checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint)
checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys()
restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name: name.split(':')[0])
restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0])
if not var_list:
var_list = tf.global_variables()
restore_vars = {restore_to_checkpoint_mapping(var.name, checkpoint_var_names): var for var in var_list}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment