diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index ce62965a2c92432ffbf739e933279f91b69e355c..664a817a95578eb561834a5e36658e0981fd5686 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -10,7 +10,7 @@ import tensorflow as tf
 from collections import OrderedDict
 from tensorflow.contrib.training import HParams
 from google.protobuf.json_format import MessageToDict
-
+from general_utils import reduce_dict
 
 class ERA5Dataset(object):
 
@@ -58,20 +58,18 @@ class ERA5Dataset(object):
 
     def get_default_hparams_dict(self):
         """
-        The function that contains default hparams
+        Provide dictionary containing default hyperparameters for the dataset
         Returns:
-            A dict with the following hyperparameters.
-            context_frames  : the number of ground-truth frames to pass in at start.
+            A dict with the following hyperparameters relevant for the dataset.
+            context_frames  : the number of ground-truth frames to pass in at start
+            batch_size      : number of training examples per mini-batch
             max_epochs      : the number of epochs to train model
-            lr              : learning rate
             loss_fun        : the loss function
         """
         hparams = dict(
             context_frames=10,
             max_epochs = 20,
             batch_size = 40,
-            lr = 0.001,
-            loss_fun = "rmse",
             shuffle_on_val= True,
         )
         return hparams
@@ -89,7 +87,9 @@ class ERA5Dataset(object):
         """
         Parse the hparams setting to ovoerride the default ones
         """
+        self.hparams_dict = reduce_dict(self.hparams_dict, self.get_default_hparams().values())
         parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+
         return parsed_hparams