diff --git a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py index 7a1da880defb61dbd018c6f11ee14c34cf0ce43e..415275e8f909560cab725ecae38bb37a01809244 100644 --- a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py +++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py @@ -526,10 +526,14 @@ def reduce_tensors(structures, shallow=False): def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None): + method = get_checkpoint_restore_saver.__name__ if os.path.isdir(checkpoint): # latest_checkpoint doesn't work when the path has special characters checkpoint = tf.train.latest_checkpoint(checkpoint) + # print name of checkpoint-file for verbosity + print("%{0}: The follwoing checkpoint is used for restoring the model: '{1}'".format(method, checkpoint)) + # Start processing the checkpoint checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint) checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys() restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0])