diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index ce62965a2c92432ffbf739e933279f91b69e355c..664a817a95578eb561834a5e36658e0981fd5686 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -10,7 +10,7 @@ import tensorflow as tf from collections import OrderedDict from tensorflow.contrib.training import HParams from google.protobuf.json_format import MessageToDict - +from general_utils import reduce_dict class ERA5Dataset(object): @@ -58,20 +58,18 @@ class ERA5Dataset(object): def get_default_hparams_dict(self): """ - The function that contains default hparams + Provide dictionary containing default hyperparameters for the dataset Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. + A dict with the following hyperparameters relevant for the dataset. + context_frames : the number of ground-truth frames to pass in at start + batch_size : number of training examples per mini-batch max_epochs : the number of epochs to train model - lr : learning rate loss_fun : the loss function """ hparams = dict( context_frames=10, max_epochs = 20, batch_size = 40, - lr = 0.001, - loss_fun = "rmse", shuffle_on_val= True, ) return hparams @@ -89,7 +87,9 @@ class ERA5Dataset(object): """ Parse the hparams setting to ovoerride the default ones """ + self.hparams_dict = reduce_dict(self.hparams_dict, self.get_default_hparams().values()) parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py index 5d0fb397ef4be4d9c1e8bd3aad6d80bdc1a9925b..9ab152bac87c76cc2e37341df597133e5af9089b 100644 --- a/video_prediction_tools/utils/general_utils.py +++ b/video_prediction_tools/utils/general_utils.py @@ -6,6 +6,7 @@ Provides: * get_unique_vars * isw * check_str_in_list * check_dir + * reduce_dict * provide_default * get_era5_atts """ @@ -152,6 +153,28 @@ def check_dir(path2dir: str, lcreate=False): raise NotADirectoryError("%{0}: Directory '{1}' does not exist".format(method, path2dir)) +def reduce_dict(dict_in: dict, dict_ref: dict): + """ + Reduces input dictionary to keys from reference dictionary. If the input dictionary lacks some keys, these are + copied over from the reference dictionary, i.e. the reference dictionary provides the defaults + :param dict_in: input dictionary + :param dict_ref: reference dictionary + :return: reduced form of input dictionary (with keys complemented from dict_ref if necessary) + """ + method = reduce_dict.__name__ + + # sanity checks + assert isinstance(dict_in, dict), "%{0}: dict_in must be a dictionary, but is of type {1}"\ + .format(method, type(dict_in)) + assert isinstance(dict_ref, dict), "%{0}: dict_ref must be a dictionary, but is of type {1}"\ + .format(method, type(dict_ref)) + + dict_merged = {**dict_ref, **dict_in} + dict_reduced = {key: dict_merged[key] for key in dict_ref} + + return dict_reduced + + def provide_default(dict_in, keyname, default=None, required=False): """ Returns values of key from input dictionary or alternatively its default