diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index 79e5e2fdf1fd2cb998399641a258a377fc3e2bce..f623636ac38248c1f46ec912258a1890cb4d1b3f 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -180,7 +180,9 @@ class ERA5Dataset(object):
         if shuffle:
             random.shuffle(filenames)
         dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) 
-        
+       
+        # cache dataset
+        dataset = dataset.cache()  
         # dataset = dataset.filter(self.filter)
         if shuffle:
             dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))