diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py index 1b6b6b31777d1a0763ca2e1d8af0b9bec79b7040..9ab152bac87c76cc2e37341df597133e5af9089b 100644 --- a/video_prediction_tools/utils/general_utils.py +++ b/video_prediction_tools/utils/general_utils.py @@ -155,10 +155,11 @@ def check_dir(path2dir: str, lcreate=False): def reduce_dict(dict_in: dict, dict_ref: dict): """ - Returns reduced version of input directory with keys only that are also part in reference dictionary + Reduces input dictionary to keys from reference dictionary. If the input dictionary lacks some keys, these are + copied over from the reference dictionary, i.e. the reference dictionary provides the defaults :param dict_in: input dictionary :param dict_ref: reference dictionary - :return: subset of input dictionary + :return: reduced form of input dictionary (with keys complemented from dict_ref if necessary) """ method = reduce_dict.__name__ @@ -167,15 +168,11 @@ def reduce_dict(dict_in: dict, dict_ref: dict): .format(method, type(dict_in)) assert isinstance(dict_ref, dict), "%{0}: dict_ref must be a dictionary, but is of type {1}"\ .format(method, type(dict_ref)) + + dict_merged = {**dict_ref, **dict_in} + dict_reduced = {key: dict_merged[key] for key in dict_ref} - if set(dict_ref.keys()).issubset(set(dict_in.keys())): - dict_in_subset = {key: dict_in[key] for key in dict_ref} - else: - print("Keys in dict_ref: {0}".format(", ".join(dict_ref.keys()))) - print("Keys in dict_in: {0}".format(", ".join(dict_in.keys()))) - raise KeyError("%{0}: Could not find all required keys from dict_ref in dict_in.".format(method)) - - return dict_in_subset + return dict_reduced def provide_default(dict_in, keyname, default=None, required=False):