diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 0b91b18174ca3416f4be2d9f96019bb8f646767e..47210f8cd69bd77d389232922361d4cb820aea5e 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -24,12 +24,14 @@ import matplotlib.pyplot as plt
 import pickle as pkl
 from model_modules.video_prediction.utils import tf_utils
 from general_utils import *
+import math
 
 
 class TrainModel(object):
     def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
                  model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
-                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01):
+                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01,
+                 frac_save_model_start: float = None, prob_save_model: float = None):
         """
         Class instance for training the models
         :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located
@@ -46,6 +48,8 @@ class TrainModel(object):
                                steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results
                                into a diagnozing intreval of 10 iteration steps (= interval over which validation loss
                                is averaged to identify best model performance)
+        :param frac_save_model_start: fraction of the total iteration steps after which checkpoint saving begins
+        :param prob_save_model: fraction of the total iteration steps used as interval between checkpoint saves (controls the saving frequency)
         """
         self.input_dir = os.path.normpath(input_dir)
         self.output_dir = os.path.normpath(output_dir)
@@ -58,6 +62,8 @@ class TrainModel(object):
         self.seed = seed
         self.args = args
         self.diag_intv_frac = diag_intv_frac
+        self.frac_save_model_start = frac_save_model_start
+        self.prob_save_model = prob_save_model
         # for diagnozing and saving the model during training
         self.saver_loss = None         # set in create_fetches_for_train-method
         self.saver_loss_name = None    # set in create_fetches_for_train-method 
@@ -77,6 +83,8 @@
         self.create_saver_and_writer()
         self.setup_gpu_config()
         self.calculate_samples_and_epochs()
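+        # derive when and how often checkpoints are saved (requires total_steps from the call above)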
+        self.calculate_checkpoint_saver_conf()
 
     def set_seed(self):
         """
@@ -240,6 +248,26 @@
         print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
               .format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps))
 
+    def calculate_checkpoint_saver_conf(self):
+        """
+        Calculate the step at which checkpoint saving starts and the interval (in steps) between
+        subsequent saves. Requires total_steps, i.e. calculate_samples_and_epochs must be called first.
+        """
+        method = TrainModel.calculate_checkpoint_saver_conf.__name__
+
+        if not hasattr(self, "total_steps"):
+            raise RuntimeError("%{}: 'calculate_samples_and_epochs' must be called to set total_steps first.".format(method))
+        if self.prob_save_model > 1. or self.prob_save_model < 0.:
+            raise ValueError("prob_save_model must be between 0 and 1.")
+        if self.frac_save_model_start > 1. or self.frac_save_model_start < 0.:
+            raise ValueError("frac_save_model_start must be between 0 and 1.")
+
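+        # derive the first checkpoint step and the saving interval from the given fractions;
+        # e.g. total_steps=1000, frac_save_model_start=0.5 and prob_save_model=0.05 yield
+        # checkpoints starting at step 500 which are then written every 50 steps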
+        self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start))
+        self.saver_interval_step = max(1, int(math.ceil(self.total_steps * self.prob_save_model)))   # at least 1 to avoid a zero modulus
+
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
         Restore the models checkpoints if the checkpoints is given
@@ -283,6 +311,22 @@
                 val_losses = pkl.load(f)
         return train_losses,val_losses
 
+    def create_checkpoints_folder(self, step: int = None):
+        """
+        Create a folder to store the checkpoint at a given step
+        :param step: the iteration step for which the checkpoint is saved
+        :return: path to the directory under which the model is saved
+        """
+        dir_name = "checkpoint_" + str(step)
+        full_dir_name = os.path.join(self.output_dir, dir_name)
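+        # TensorFlow's Saver writes a "checkpoint" state file into the target directory,
+        # so its presence indicates that this step has already been saved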
+        if os.path.isfile(os.path.join(full_dir_name, "checkpoint")):
+            print("The checkpoint at step {} already exists.".format(step))
+        else:
+            os.makedirs(full_dir_name, exist_ok=True)
+        return full_dir_name
+
     def train_model(self):
         """
         Start session and train the model by looping over all iteration steps
@@ -303,7 +347,7 @@
             # initialize auxiliary variables
             time_per_iteration = []
             run_start_time = time.time()
-            val_loss_min = 999.
+
             # perform iteration
             for step in range(start_step, self.total_steps):
                 timeit_start = time.time()
@@ -324,15 +368,17 @@
                 time_iter = time.time() - timeit_start
                 time_per_iteration.append(time_iter)
                 print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
-                if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1):
-                    lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step)
-                    # save best and final model state
-                    if lsave or step == self.total_steps - 1:
-                        self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"),
-                                        global_step=step)
-                    # pickle file and plots are always created
-                    TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
-                    TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
+
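+                # save a checkpoint at every saver interval once the start step is reached, and always at the final step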
+                if (step >= self.start_checkpoint_step and step % self.saver_interval_step == 0) or step == self.total_steps - 1:
+                    # create a checkpoint folder for this step
+                    full_dir_name = self.create_checkpoints_folder(step=step)
+                    self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step)
+
+                # pickle file and plots are always refreshed at the diagnosis interval and at the final step
+                if step % self.diag_intv_step == 0 or step == self.total_steps - 1:
+                    TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
+                    TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
 
             # Final diagnostics
             # track time (save to pickle-files)