Skip to content
Snippets Groups Projects
Commit 655aa68e authored by Michael Langguth's avatar Michael Langguth
Browse files

Merge branch 'michael_issue#110_parse_hyperparameters_to_dataset_instance'...

Merge branch 'michael_issue#110_parse_hyperparameters_to_dataset_instance' into develop. This step was repeated after cherry-picking from michael_issue#116_extract_best_model in order solve parsing errors due to parameters of used dataset.
parents 4416932e d47c390d
No related branches found
No related tags found
No related merge requests found
Pipeline #70893 passed
...@@ -10,7 +10,7 @@ import tensorflow as tf ...@@ -10,7 +10,7 @@ import tensorflow as tf
from collections import OrderedDict from collections import OrderedDict
from tensorflow.contrib.training import HParams from tensorflow.contrib.training import HParams
from google.protobuf.json_format import MessageToDict from google.protobuf.json_format import MessageToDict
from general_utils import reduce_dict
class ERA5Dataset(object): class ERA5Dataset(object):
...@@ -58,20 +58,18 @@ class ERA5Dataset(object): ...@@ -58,20 +58,18 @@ class ERA5Dataset(object):
def get_default_hparams_dict(self): def get_default_hparams_dict(self):
""" """
The function that contains default hparams Provide dictionary containing default hyperparameters for the dataset
Returns: Returns:
A dict with the following hyperparameters. A dict with the following hyperparameters relevant for the dataset.
context_frames : the number of ground-truth frames to pass in at start. context_frames : the number of ground-truth frames to pass in at start
batch_size : number of training examples per mini-batch
max_epochs : the number of epochs to train model max_epochs : the number of epochs to train model
lr : learning rate
loss_fun : the loss function loss_fun : the loss function
""" """
hparams = dict( hparams = dict(
context_frames=10, context_frames=10,
max_epochs = 20, max_epochs = 20,
batch_size = 40, batch_size = 40,
lr = 0.001,
loss_fun = "rmse",
shuffle_on_val= True, shuffle_on_val= True,
) )
return hparams return hparams
...@@ -89,7 +87,9 @@ class ERA5Dataset(object): ...@@ -89,7 +87,9 @@ class ERA5Dataset(object):
""" """
Parse the hparams setting to ovoerride the default ones Parse the hparams setting to ovoerride the default ones
""" """
self.hparams_dict = reduce_dict(self.hparams_dict, self.get_default_hparams().values())
parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
return parsed_hparams return parsed_hparams
......
...@@ -6,6 +6,7 @@ Provides: * get_unique_vars ...@@ -6,6 +6,7 @@ Provides: * get_unique_vars
* isw * isw
* check_str_in_list * check_str_in_list
* check_dir * check_dir
* reduce_dict
* provide_default * provide_default
* get_era5_atts * get_era5_atts
""" """
...@@ -152,6 +153,28 @@ def check_dir(path2dir: str, lcreate=False): ...@@ -152,6 +153,28 @@ def check_dir(path2dir: str, lcreate=False):
raise NotADirectoryError("%{0}: Directory '{1}' does not exist".format(method, path2dir)) raise NotADirectoryError("%{0}: Directory '{1}' does not exist".format(method, path2dir))
def reduce_dict(dict_in: dict, dict_ref: dict):
"""
Reduces input dictionary to keys from reference dictionary. If the input dictionary lacks some keys, these are
copied over from the reference dictionary, i.e. the reference dictionary provides the defaults
:param dict_in: input dictionary
:param dict_ref: reference dictionary
:return: reduced form of input dictionary (with keys complemented from dict_ref if necessary)
"""
method = reduce_dict.__name__
# sanity checks
assert isinstance(dict_in, dict), "%{0}: dict_in must be a dictionary, but is of type {1}"\
.format(method, type(dict_in))
assert isinstance(dict_ref, dict), "%{0}: dict_ref must be a dictionary, but is of type {1}"\
.format(method, type(dict_ref))
dict_merged = {**dict_ref, **dict_in}
dict_reduced = {key: dict_merged[key] for key in dict_ref}
return dict_reduced
def provide_default(dict_in, keyname, default=None, required=False): def provide_default(dict_in, keyname, default=None, required=False):
""" """
Returns values of key from input dictionary or alternatively its default Returns values of key from input dictionary or alternatively its default
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment