Skip to content
Snippets Groups Projects
Commit 7424bcc4 authored by gong1's avatar gong1
Browse files

check if checkpoint exists

parent bd000610
No related branches found
No related tags found
No related merge requests found
...@@ -259,7 +259,7 @@ class TrainModel(object): ...@@ -259,7 +259,7 @@ class TrainModel(object):
""" """
Restore the train and validation losses in the pickle file if checkpoint is given Restore the train and validation losses in the pickle file if checkpoint is given
""" """
if self.start_step == 0: if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
train_losses = [] train_losses = []
val_losses = [] val_losses = []
else: else:
...@@ -278,6 +278,7 @@ class TrainModel(object): ...@@ -278,6 +278,7 @@ class TrainModel(object):
print("parameter_count =", sess.run(self.parameter_count)) print("parameter_count =", sess.run(self.parameter_count))
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer()) sess.run(tf.local_variables_initializer())
if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
self.restore(sess, self.checkpoint) self.restore(sess, self.checkpoint)
#sess.graph.finalize() #sess.graph.finalize()
self.start_step = sess.run(self.global_step) self.start_step = sess.run(self.global_step)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment