diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 0b91b18174ca3416f4be2d9f96019bb8f646767e..a54abb14158c4466a9f5783c93f0c09f78894118 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -24,12 +24,13 @@ import matplotlib.pyplot as plt import pickle as pkl from model_modules.video_prediction.utils import tf_utils from general_utils import * +import math class TrainModel(object): def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None, model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None, - gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01): + gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float=None, prob_save_model:float=None): """ Class instance for training the models :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located @@ -46,6 +47,8 @@ class TrainModel(object): steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results into a diagnozing intreval of 10 iteration steps (= interval over which validation loss is averaged to identify best model performance) + :param frac_save_model_start: fraction of total iterations steps as the start point to save checkpoints + :param prob_save_model: probabability that model are saved to checkpoint (control the frequences of saving model0) """ self.input_dir = os.path.normpath(input_dir) self.output_dir = os.path.normpath(output_dir) @@ -58,6 +61,8 @@ class TrainModel(object): self.seed = seed self.args = args self.diag_intv_frac = diag_intv_frac + self.frac_save_model_start = frac_save_model_start + self.prob_save_model = prob_save_model # for diagnozing and saving the model during training self.saver_loss = None # set in create_fetches_for_train-method self.saver_loss_name = None # set in create_fetches_for_train-method @@ -77,6 +82,8 @@ class TrainModel(object): self.create_saver_and_writer() self.setup_gpu_config() self.calculate_samples_and_epochs() + self.calculate_checkpoint_saver_conf() + def set_seed(self): """ @@ -214,7 +221,7 @@ class TrainModel(object): """ Create saver to save the models latest checkpoints, and a summery writer to store the train/val metrics """ - self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables, max_to_keep=2) + self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables, max_to_keep=None) self.summary_writer = tf.summary.FileWriter(self.output_dir) def setup_gpu_config(self): @@ -240,6 +247,26 @@ class TrainModel(object): print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}" .format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps)) + + def calculate_checkpoint_saver_conf(self): + """ + Calculate the start step for saving the checkpoint, and the frequences steps to save model + + """ + method = TrainModel.calculate_checkpoint_saver_conf.__name__ + + if hasattr(self.total_steps, "attr_name"): + raise SyntaxError(" function 'calculate_sample_and_epochs' is required to call to calcualte the total_step before all function {}".format(method)) + if self.prob_save_model > 1 or self.prob_save_model<0 : + raise ValueError("pro_save_model should be less than 1 and larger than 0") + if self.frac_save_model_start > 1 or self.frac_save_model_start<0: + raise ValueError("frac_save_model_start should be less than 1 and larger than 0") + + self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start)) + self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model)) + print("The model will be saved starting from step {} with {} interval step ".format(str(self.start_checkpoint_step),self.saver_interval_step)) + + def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): """ Restore the models checkpoints if the checkpoints is given @@ -283,6 +310,22 @@ class TrainModel(object): val_losses = pkl.load(f) return train_losses,val_losses + def create_checkpoints_folder(self, step:int=None): + """ + Create a folder to store checkpoint at certain step + :param step: the step you want to save the checkpoint + return : dir path to save model + """ + dir_name = "checkpoint_" + str(step) + full_dir_name = os.path.join(self.output_dir,dir_name) + if os.path.isfile(os.path.join(full_dir_name,"checkpoints")): + print("The checkpoint at step {} exists".format(step)) + else: + os.mkdir(full_dir_name) + return full_dir_name + + + def train_model(self): """ Start session and train the model by looping over all iteration steps @@ -303,7 +346,7 @@ class TrainModel(object): # initialize auxiliary variables time_per_iteration = [] run_start_time = time.time() - val_loss_min = 999. + # perform iteration for step in range(start_step, self.total_steps): timeit_start = time.time() @@ -324,15 +367,15 @@ class TrainModel(object): time_iter = time.time() - timeit_start time_per_iteration.append(time_iter) print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter)) - if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1): - lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step) - # save best and final model state - if lsave or step == self.total_steps - 1: - self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"), - global_step=step) - # pickle file and plots are always created - TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir) - TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir) + + if step > self.start_checkpoint_step and (step % self.saver_interval_step == 0 or step == self.total_steps - 1): + #create a checkpoint folder for step + full_dir_name = self.create_checkpoints_folder(step=step) + self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step) + + # pickle file and plots are always created + TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir) + TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir) # Final diagnostics # track time (save to pickle-files) @@ -551,6 +594,8 @@ def main(): parser.add_argument("--model", type=str, help="Model class name") parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters") parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use") + parser.add_argument("--frac_save_model_start", type=float,default=0.6,help="fraction of the start step for saving checkpoint") + parser.add_argument("--prob_save_model", type = float, default = 0.01, help = "probabability that model are saved to checkpoint (control the frequences of saving model") parser.add_argument("--seed", default=1234, type=int) args = parser.parse_args() @@ -558,8 +603,9 @@ def main(): timeit_start_total_time = time.time() #create a training instance train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict, - model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint,dataset=args.dataset, - gpu_mem_frac=args.gpu_mem_frac,seed=args.seed,args=args) + model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset, + gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_save_model_start=args.frac_save_model_start, + prob_save_model=args.prob_save_model) print('----------------------------------- Options ------------------------------------') for k, v in args._get_kwargs():