Skip to content
Snippets Groups Projects
Commit aee36dd3 authored by Michael Langguth's avatar Michael Langguth
Browse files

Add print-statement to identify checkpoint that is restored.

parent 7a54fa1a
No related branches found
No related tags found
No related merge requests found
Pipeline #75864 passed
...@@ -526,10 +526,14 @@ def reduce_tensors(structures, shallow=False): ...@@ -526,10 +526,14 @@ def reduce_tensors(structures, shallow=False):
def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None): def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None):
method = get_checkpoint_restore_saver.__name__
if os.path.isdir(checkpoint): if os.path.isdir(checkpoint):
# latest_checkpoint doesn't work when the path has special characters # latest_checkpoint doesn't work when the path has special characters
checkpoint = tf.train.latest_checkpoint(checkpoint) checkpoint = tf.train.latest_checkpoint(checkpoint)
# print name of checkpoint-file for verbosity
print("%{0}: The follwoing checkpoint is used for restoring the model: '{1}'".format(method, checkpoint))
# Start processing the checkpoint
checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint) checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint)
checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys() checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys()
restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0]) restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment