Skip to content
Snippets Groups Projects
Commit abf6e869 authored by gong1's avatar gong1
Browse files

restore checkpoint path bug fixed

parent f81f5c92
No related branches found
No related tags found
No related merge requests found
Pipeline #69887 passed
......@@ -259,7 +259,10 @@ class TrainModel(object):
"""
Restore the train and validation losses in the pickle file if checkpoint is given
"""
if not os.path.exists(os.path.join(self.output_dir,"checkpoint")):
if (not os.path.exists(os.path.join(self.output_dir,"checkpoint"))) and (self.checkpoint is not None):
raise FileNotFoundError("The user-configured checkpoint path {} does not exists!!".format(self.checkpoint))
if (not os.path.exists(os.path.join(self.output_dir,"checkpoint"))) or (self.checkpoint is None):
train_losses = []
val_losses = []
else:
......@@ -294,15 +297,11 @@ class TrainModel(object):
self.create_fetches_for_train() # In addition to the loss, we fetch the optimizer
self.results = sess.run(self.fetches) # ...and run it here!
train_losses.append(self.results["total_loss"])
print("t_start for training",self.results["inputs"]["T_start"])
print("len of t_start per iteration",len(self.results["inputs"]["T_start"]))
#Run and fetch losses for validation data
val_handle_eval = sess.run(self.val_handle)
self.create_fetches_for_val()
self.val_results = sess.run(self.val_fetches,feed_dict={self.train_handle: val_handle_eval})
val_losses.append(self.val_results["total_loss"])
print("t_start for validation",self.val_results["inputs"]["T_start"])
print("len of t_start per iteration",len(self.val_results["inputs"]["T_start"]))
self.write_to_summary()
self.print_results(step,self.results)
timeit_end = time.time()
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment