Commit ebded8a4 authored by BING GONG

Create checkpoint folders and save checkpoint at certain step

parent e0de9e0b
Pipeline #76838 passed
@@ -24,12 +24,13 @@ import matplotlib.pyplot as plt
 import pickle as pkl
 from model_modules.video_prediction.utils import tf_utils
 from general_utils import *
+import math
 
 
 class TrainModel(object):
     def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
                  model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
-                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01):
+                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float = None, prob_save_model: float = None):
         """
         Class instance for training the models
         :param input_dir: parent directory under which the "pickle" and "tfrecords" file directories are located
@@ -46,6 +47,8 @@ class TrainModel(object):
                                steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results
                                in a diagnosing interval of 10 iteration steps (= interval over which validation loss
                                is averaged to identify best model performance)
+        :param frac_save_model_start: fraction of the total iteration steps after which checkpoint saving starts
+        :param prob_save_model: fraction of the total iteration steps used as the interval between checkpoint saves (controls how frequently the model is saved)
         """
         self.input_dir = os.path.normpath(input_dir)
         self.output_dir = os.path.normpath(output_dir)
@@ -58,6 +61,8 @@ class TrainModel(object):
         self.seed = seed
         self.args = args
         self.diag_intv_frac = diag_intv_frac
+        self.frac_save_model_start = frac_save_model_start
+        self.prob_save_model = prob_save_model
         # for diagnosing and saving the model during training
         self.saver_loss = None       # set in create_fetches_for_train-method
         self.saver_loss_name = None  # set in create_fetches_for_train-method
...@@ -77,6 +82,8 @@ class TrainModel(object): ...@@ -77,6 +82,8 @@ class TrainModel(object):
self.create_saver_and_writer() self.create_saver_and_writer()
self.setup_gpu_config() self.setup_gpu_config()
self.calculate_samples_and_epochs() self.calculate_samples_and_epochs()
self.calculate_checkpoint_saver_conf()
def set_seed(self): def set_seed(self):
""" """
@@ -240,6 +247,25 @@ class TrainModel(object):
         print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
               .format(method, batch_size, max_epochs, self.num_examples, self.steps_per_epoch, self.total_steps))
 
+    def calculate_checkpoint_saver_conf(self):
+        """
+        Calculate the step at which checkpoint saving starts and the interval (in steps) between saved checkpoints
+        """
+        method = TrainModel.calculate_checkpoint_saver_conf.__name__
+
+        if not hasattr(self, "total_steps"):
+            raise RuntimeError("%{}: 'calculate_samples_and_epochs' must be called to set total_steps before this method".format(method))
+        if self.prob_save_model > 1 or self.prob_save_model < 0:
+            raise ValueError("prob_save_model must lie between 0 and 1")
+        if self.frac_save_model_start > 1 or self.frac_save_model_start < 0:
+            raise ValueError("frac_save_model_start must lie between 0 and 1")
+
+        self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start))
+        self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model))
+
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
         Restore the model's checkpoints if checkpoints are given
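
For orientation, a minimal worked example (all values assumed; total_steps is normally derived by calculate_samples_and_epochs) of the two quantities the new calculate_checkpoint_saver_conf method computes:

    # Worked example with assumed values; not part of the commit.
    import math

    total_steps = 1000            # assumed; normally set by calculate_samples_and_epochs
    frac_save_model_start = 0.6   # start saving after 60% of training
    prob_save_model = 0.01        # save every 1% of the total steps

    start_checkpoint_step = int(math.ceil(total_steps * frac_save_model_start))
    saver_interval_step = int(math.ceil(total_steps * prob_save_model))
    print(start_checkpoint_step, saver_interval_step)  # -> 600 10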
@@ -283,6 +309,22 @@ class TrainModel(object):
             val_losses = pkl.load(f)
         return train_losses, val_losses
 
+    def create_checkpoints_folder(self, step: int = None):
+        """
+        Create a folder to store the checkpoint at a given step
+        :param step: the step at which the checkpoint is saved
+        :return: directory path under which the model is saved
+        """
+        dir_name = "checkpoint_" + str(step)
+        full_dir_name = os.path.join(self.output_dir, dir_name)
+        if os.path.isfile(os.path.join(full_dir_name, "checkpoints")):
+            print("The checkpoint at step {} already exists".format(step))
+        else:
+            os.makedirs(full_dir_name, exist_ok=True)
+        return full_dir_name
+
     def train_model(self):
         """
         Start session and train the model by looping over all iteration steps
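
A hypothetical sketch of the on-disk layout this method produces together with the saver.save call in train_model below (output_dir and the step values are assumptions, not from the commit). One caveat: TF1's tf.train.Saver writes its state file as "checkpoint" (singular), so the "checkpoints" existence check above only fires if such a file is written elsewhere in the repository.

    # Hypothetical layout sketch; paths and steps are assumed.
    import os

    output_dir = "/path/to/output"   # assumed output directory
    for step in (610, 620):          # example save steps
        ckpt_dir = os.path.join(output_dir, "checkpoint_{}".format(step))
        # saver.save(sess, os.path.join(ckpt_dir, "model_"), global_step=step) would
        # write model_-610.index, model_-610.meta, model_-610.data-00000-of-00001
        # plus a "checkpoint" state file inside ckpt_dir.
        print(os.path.join(ckpt_dir, "model_-{}".format(step)))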
@@ -303,7 +345,7 @@ class TrainModel(object):
         # initialize auxiliary variables
         time_per_iteration = []
         run_start_time = time.time()
-        val_loss_min = 999.
+
         # perform iteration
         for step in range(start_step, self.total_steps):
             timeit_start = time.time()
@@ -324,12 +366,12 @@ class TrainModel(object):
             time_iter = time.time() - timeit_start
             time_per_iteration.append(time_iter)
             print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
-            if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1):
-                lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step)
-                # save best and final model state
-                if lsave or step == self.total_steps - 1:
-                    self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"),
-                                    global_step=step)
+            if step > self.start_checkpoint_step and (step % self.saver_interval_step == 0 or step == self.total_steps - 1):
+                # create a checkpoint folder for this step
+                full_dir_name = self.create_checkpoints_folder(step=step)
+                self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step)
             # pickle file and plots are always created
             TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
             TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
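
Taken together, the new trigger condition saves strictly after start_checkpoint_step, every saver_interval_step steps, and always at the final step. A pure-Python sketch, reusing the assumed values from the worked example above:

    # Sketch of which steps trigger a save; values assumed as above.
    import math

    total_steps = 1000
    start_checkpoint_step = int(math.ceil(total_steps * 0.6))   # 600
    saver_interval_step = int(math.ceil(total_steps * 0.01))    # 10

    save_steps = [step for step in range(total_steps)
                  if step > start_checkpoint_step
                  and (step % saver_interval_step == 0 or step == total_steps - 1)]
    print(save_steps[0], save_steps[-1], len(save_steps))  # -> 610 999 40

Note that step 600 itself does not trigger a save, because the comparison is strict (>).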