diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 46874ab863aa5c2c5ac304b0df6c3986d8490fc5..5616b8b1ca4da6e8b15e9338a31831fd061f69b4 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -30,7 +30,7 @@ import math class TrainModel(object): def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None, model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None, - gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float = None, prob_save_model:float = None): + gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float=None, prob_save_model:float=None): """ Class instance for training the models :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located @@ -265,7 +265,7 @@ class TrainModel(object): self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start)) self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model)) print("The model will be saved starting from step {} with {} interval step ".format(str(self.start_checkpoint_step),self.saver_interval_step)) - + def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): """ @@ -594,6 +594,8 @@ def main(): parser.add_argument("--model", type=str, help="Model class name") parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters") parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use") + parser.add_argument("--frac_save_model_start", type=float,default=0.6,help="fraction of the start step for saving checkpoint") + parser.add_argument("--prob_save_model", type = float, default = 0.01, help = "probabability that model are saved to checkpoint (control the frequences of saving model") parser.add_argument("--seed", default=1234, type=int) args = parser.parse_args() @@ -601,8 +603,9 @@ def main(): timeit_start_total_time = time.time() #create a training instance train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict, - model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint,dataset=args.dataset, - gpu_mem_frac=args.gpu_mem_frac,seed=args.seed,args=args) + model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset, + gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_save_model_start=args.frac_save_model_start, + prob_save_model=args.prob_save_model) print('----------------------------------- Options ------------------------------------') for k, v in args._get_kwargs():