diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index 79e5e2fdf1fd2cb998399641a258a377fc3e2bce..f623636ac38248c1f46ec912258a1890cb4d1b3f 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -180,7 +180,9 @@ class ERA5Dataset(object): if shuffle: random.shuffle(filenames) dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) - + + # cache dataset + dataset = dataset.cache() # dataset = dataset.filter(self.filter) if shuffle: dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))