Skip to content
Snippets Groups Projects
Commit 7424bcc4 authored by gong1's avatar gong1
Browse files

check if checkpoint exists

parent bd000610
Branches
Tags
No related merge requests found
...@@ -259,7 +259,7 @@ class TrainModel(object): ...@@ -259,7 +259,7 @@ class TrainModel(object):
""" """
Restore the train and validation losses in the pickle file if checkpoint is given Restore the train and validation losses in the pickle file if checkpoint is given
""" """
if self.start_step == 0: if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
train_losses = [] train_losses = []
val_losses = [] val_losses = []
else: else:
...@@ -278,6 +278,7 @@ class TrainModel(object): ...@@ -278,6 +278,7 @@ class TrainModel(object):
print("parameter_count =", sess.run(self.parameter_count)) print("parameter_count =", sess.run(self.parameter_count))
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) sess.run(tf.local_variables_initializer())
if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
self.restore(sess, self.checkpoint) self.restore(sess, self.checkpoint)
#sess.graph.finalize() #sess.graph.finalize()
self.start_step = sess.run(self.global_step) self.start_step = sess.run(self.global_step)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment