diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 0b91b18174ca3416f4be2d9f96019bb8f646767e..47210f8cd69bd77d389232922361d4cb820aea5e 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -24,12 +24,13 @@ import matplotlib.pyplot as plt
 import pickle as pkl
 from model_modules.video_prediction.utils import tf_utils
 from general_utils import *
+import math
 
 
 class TrainModel(object):
     def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
                  model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
-                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01):
+                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float = None, prob_save_model: float = None):
         """
         Class instance for training the models
         :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located
@@ -46,6 +47,8 @@ class TrainModel(object):
                                steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results
                                into a diagnozing intreval of 10 iteration steps (= interval over which validation loss
                                is averaged to identify best model performance)
+        :param frac_save_model_start: fraction of the total iteration steps after which checkpoint saving starts
+        :param prob_save_model: fraction of the total iteration steps used as the interval between saved checkpoints (controls how frequently the model is saved)
         """
         self.input_dir = os.path.normpath(input_dir)
         self.output_dir = os.path.normpath(output_dir)
@@ -58,6 +61,8 @@ class TrainModel(object):
         self.seed = seed
         self.args = args
         self.diag_intv_frac = diag_intv_frac
+        self.frac_save_model_start = frac_save_model_start
+        self.prob_save_model = prob_save_model
         # for diagnozing and saving the model during training
         self.saver_loss = None       # set in create_fetches_for_train-method
         self.saver_loss_name = None  # set in create_fetches_for_train-method
@@ -77,6 +82,8 @@ class TrainModel(object):
         self.create_saver_and_writer()
         self.setup_gpu_config()
         self.calculate_samples_and_epochs()
+        self.calculate_checkpoint_saver_conf()
+
 
     def set_seed(self):
         """
@@ -240,6 +247,25 @@ class TrainModel(object):
         print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
               .format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps))
 
+
+    def calculate_checkpoint_saver_conf(self):
+        """
+        Calculate the step at which checkpoint saving starts and the interval (in steps) between saved checkpoints.
+        """
+        method = TrainModel.calculate_checkpoint_saver_conf.__name__
+
+        if not hasattr(self, "total_steps"):
+            raise RuntimeError("%{0}: 'calculate_samples_and_epochs' must be called beforehand to set total_steps."
+                               .format(method))
+        if self.prob_save_model > 1. or self.prob_save_model < 0.:
+            raise ValueError("prob_save_model must lie between 0 and 1.")
+        if self.frac_save_model_start > 1. or self.frac_save_model_start < 0.:
+            raise ValueError("frac_save_model_start must lie between 0 and 1.")
+
+        self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start))
+        self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model))
+
+
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
         Restore the models checkpoints if the checkpoints is given
@@ -283,6 +309,22 @@ class TrainModel(object):
             val_losses = pkl.load(f)
         return train_losses,val_losses
 
+    def create_checkpoints_folder(self, step: int = None):
+        """
+        Create a folder to store the checkpoint at a certain step.
+        :param step: the iteration step for which the checkpoint is saved
+        :return: directory path under which the model is saved
+        """
+        dir_name = "checkpoint_" + str(step)
+        full_dir_name = os.path.join(self.output_dir, dir_name)
+        if os.path.isfile(os.path.join(full_dir_name, "checkpoints")):
+            print("The checkpoint at step {} already exists".format(step))
+        else:
+            os.makedirs(full_dir_name, exist_ok=True)
+        return full_dir_name
+
+
+
     def train_model(self):
         """
         Start session and train the model by looping over all iteration steps
@@ -303,7 +345,7 @@ class TrainModel(object):
             # initialize auxiliary variables
             time_per_iteration = []
            run_start_time = time.time()
-            val_loss_min = 999.
+
            # perform iteration
            for step in range(start_step, self.total_steps):
                timeit_start = time.time()
@@ -324,15 +366,15 @@ class TrainModel(object):
                time_iter = time.time() - timeit_start
                time_per_iteration.append(time_iter)
                print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
-                if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1):
-                    lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step)
-                    # save best and final model state
-                    if lsave or step == self.total_steps - 1:
-                        self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"),
-                                        global_step=step)
-                    # pickle file and plots are always created
-                    TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
-                    TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
+
+                if step > self.start_checkpoint_step and (step % self.saver_interval_step == 0 or step == self.total_steps - 1):
+                    # create a checkpoint folder for this step
+                    full_dir_name = self.create_checkpoints_folder(step=step)
+                    self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step)
+
+                # pickle file and plots are always created
+                TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
+                TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
 
            # Final diagnostics
            # track time (save to pickle-files)
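
For illustration only (not part of the patch): a minimal standalone sketch of the checkpoint schedule implied by calculate_checkpoint_saver_conf and the saving condition in train_model, assuming hypothetical values total_steps = 1000, frac_save_model_start = 0.6 and prob_save_model = 0.01.

import math

# assumed example values, chosen for illustration
total_steps = 1000
frac_save_model_start = 0.6   # start saving after 60% of all iteration steps
prob_save_model = 0.01        # save roughly every 1% of the total steps

start_checkpoint_step = int(math.ceil(total_steps * frac_save_model_start))  # 600
saver_interval_step = int(math.ceil(total_steps * prob_save_model))          # 10

# steps at which a checkpoint folder would be created (same condition as in train_model)
save_steps = [step for step in range(total_steps)
              if step > start_checkpoint_step
              and (step % saver_interval_step == 0 or step == total_steps - 1)]
print(save_steps[0], save_steps[-1], len(save_steps))  # 610 999 40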