diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 46874ab863aa5c2c5ac304b0df6c3986d8490fc5..5616b8b1ca4da6e8b15e9338a31831fd061f69b4 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -30,7 +30,7 @@ import math
 class TrainModel(object):
     def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
                  model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
-                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float = None, prob_save_model:float = None):
+                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float=None, prob_save_model:float=None):
         """
         Class instance for training the models
         :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located
@@ -265,7 +265,7 @@ class TrainModel(object):
         self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start))
         self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model))
         print("The model will be saved starting from step {} with {} interval step ".format(str(self.start_checkpoint_step),self.saver_interval_step))
-        
+
 
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
@@ -594,6 +594,8 @@ def main():
     parser.add_argument("--model", type=str, help="Model class name")
     parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters")
     parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use")
+    parser.add_argument("--frac_save_model_start", type=float,default=0.6,help="fraction of the start step for saving checkpoint")
+    parser.add_argument("--prob_save_model", type = float, default = 0.01, help = "probabability that model are saved to checkpoint (control the frequences of saving model")
     parser.add_argument("--seed", default=1234, type=int)
 
     args = parser.parse_args()
@@ -601,8 +603,9 @@ def main():
     timeit_start_total_time = time.time()  
     #create a training instance
     train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict,
-                 model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint,dataset=args.dataset,
-                 gpu_mem_frac=args.gpu_mem_frac,seed=args.seed,args=args)  
+                 model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset,
+                 gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_save_model_start=args.frac_save_model_start,
+                 prob_save_model=args.prob_save_model)
     
     print('----------------------------------- Options ------------------------------------')
     for k, v in args._get_kwargs():