diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index ff9486cd29cf928f49cd8d161af28eca56380dca..d54393473313e0604028db278d652b18fbd00d4b 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -159,12 +159,12 @@ class TrainModel(object):
         # create dataset instance
         VideoDataset = datasets.get_dataset_class(self.dataset)
 
-        self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict,
+        self.train_dataset = VideoDataset(input_dir=self.input_dir, output_dir = self.output_dir, mode='train', datasplit_config=self.datasplit_dict,
                                           hparams_dict_config=self.model_hparams_dict)
-        self.calculate_samples_and_epochs()
+        self.num_examples = self.calculate_samples_and_epochs()
         self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
         # set-up validation dataset and calculate number of batches for calculating validation loss
-        self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict,
+        self.val_dataset = VideoDataset(input_dir=self.input_dir, output_dir = self.output_dir, mode='val', datasplit_config=self.datasplit_dict,
                                         hparams_dict_config=self.model_hparams_dict, nsamples_ref=self.num_examples)
         # Retrieve sequence length from dataset
         self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
@@ -175,7 +175,7 @@ class TrainModel(object):
         :param mode: "train" used the model graph in train process;  "test" for postprocessing step
         """
         VideoPredictionModel = models.get_model_class(self.model)
-        self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict, mode=mode)
+        self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load, mode=mode)
 
     def setup_graph(self):
         """
@@ -201,8 +201,8 @@ class TrainModel(object):
         self.inputs = self.iterator.get_next()
         # since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP
         # Otherwise an error will be risen by SAVP 
-        if self.dataset == "era5" and self.model == "savp":
-            del self.inputs["T_start"]
+        if self.dataset == "era5":
+            self.inputs = self.inputs
 
     def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model):
         """
diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
index 7393911145db16798d3a8acdf3ff30870a32e0ad..0c527b2b0f403714bfc4020e2778d569abb6ef8d 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
@@ -14,10 +14,8 @@ import tensorflow as tf
 
 class BaseModels(ABC):
 
-    def __init__(self, hparams_dict_config=None):
-        self.hparams_dict_config = hparams_dict_config
-        self.hparams_dict = self.get_model_hparams_dict()
-        self.hparams = self.parse_hparams()
+    def __init__(self, hparams_dict_config:dict = None):
+        self.hparams = self.parse_hparams(hparams_dict_config)
         # Attributes set during runtime
         self.total_loss = None
         self.loss_summary = None
@@ -33,24 +31,12 @@ class BaseModels(ABC):
         self.x_hat_predict_frames = None
 
 
-    def get_model_hparams_dict(self):
-        """
-        Get model_hparams_dict from json file
-        """
-        if self.hparams_dict_config:
-            with open(self.hparams_dict_config, 'r') as f:
-                hparams_dict = json.loads(f.read())
-        else:
-            raise FileNotFoundError("hparam directory doesn't exist! please check {}!".format(self.hparams_dict_config))
 
-        return hparams_dict
-
-    def parse_hparams(self):
+    def parse_hparams(self,hparams_dict_config):
         """
         Obtain the parameters from directory
         """
-
-        hparams = dotdict(self.hparams_dict)
+        hparams = dotdict(hparams_dict_config)
         return hparams
 
     @abstractmethod