diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index ce62965a2c92432ffbf739e933279f91b69e355c..664a817a95578eb561834a5e36658e0981fd5686 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -10,7 +10,7 @@ import tensorflow as tf from collections import OrderedDict from tensorflow.contrib.training import HParams from google.protobuf.json_format import MessageToDict - +from general_utils import reduce_dict class ERA5Dataset(object): @@ -58,20 +58,18 @@ class ERA5Dataset(object): def get_default_hparams_dict(self): """ - The function that contains default hparams + Provide dictionary containing default hyperparameters for the dataset Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. + A dict with the following hyperparameters relevant for the dataset. + context_frames : the number of ground-truth frames to pass in at start + batch_size : number of training examples per mini-batch max_epochs : the number of epochs to train model - lr : learning rate loss_fun : the loss function """ hparams = dict( context_frames=10, max_epochs = 20, batch_size = 40, - lr = 0.001, - loss_fun = "rmse", shuffle_on_val= True, ) return hparams @@ -89,7 +87,9 @@ class ERA5Dataset(object): """ Parse the hparams setting to ovoerride the default ones """ + self.hparams_dict = reduce_dict(self.hparams_dict, self.get_default_hparams().values()) parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams