Skip to content
Snippets Groups Projects
Commit d7a84a40 authored by gong1's avatar gong1
Browse files

Merge branch 'bing_issue#078_allow_empty_checkpoint' into develop

parents 92d366e9 829372f7
Branches
Tags
No related merge requests found
Pipeline #68379 passed
......@@ -50,8 +50,5 @@ dataset=era5
# run training
srun python ../main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \
--dataset ${dataset} --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}
--dataset ${dataset} --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir} --checkpoint_dir ${destination_dir}
......@@ -259,7 +259,7 @@ class TrainModel(object):
"""
Restore the train and validation losses in the pickle file if checkpoint is given
"""
if self.start_step == 0:
if not os.path.exists(os.path.join(self.output_dir,"checkpoint")):
train_losses = []
val_losses = []
else:
......@@ -278,6 +278,7 @@ class TrainModel(object):
print("parameter_count =", sess.run(self.parameter_count))
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
if os.path.exists(os.path.join(self.output_dir,"checkpoint")):
self.restore(sess, self.checkpoint)
#sess.graph.finalize()
self.start_step = sess.run(self.global_step)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment