diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py index 5d0fb397ef4be4d9c1e8bd3aad6d80bdc1a9925b..2cf5d224d189f52b602da9fcbe24efbbd10542f8 100644 --- a/video_prediction_tools/utils/general_utils.py +++ b/video_prediction_tools/utils/general_utils.py @@ -152,6 +152,31 @@ def check_dir(path2dir: str, lcreate=False): raise NotADirectoryError("%{0}: Directory '{1}' does not exist".format(method, path2dir)) +def reduce_dict(dict_in: dict, dict_ref: dict): + """ + Returns reduced version of input directory with keys only that are also part in reference dictionary + :param dict_in: input dictionary + :param dict_ref: reference dictionary + :return: subset of input dictionary + """ + method = reduce_dict.__name__ + + # sanity checks + assert isinstance(dict_in, dict), "%{0}: dict_in must be a dictionary, but is of type {1}"\ + .format(method, type(dict_in)) + assert isinstance(dict_ref, dict), "%{0}: dict_ref must be a dictionary, but is of type {1}"\ + .format(method, type(dict_ref)) + + if set(dict_ref.keys()).issubset(set(dict_in.keys())): + dict_in_subset = {key: dict_in[key] for key in dict_ref} + else: + print("Keys in dict_ref: {0}".format(", ".join(dict_ref.keys()))) + print("Keys in dict_in: {0}".format(", ".join(dict_in.keys()))) + raise KeyError("%{0}: Could not find all required keys from dict_ref in dict_in.".format(method)) + + return dict_in_subset + + def provide_default(dict_in, keyname, default=None, required=False): """ Returns values of key from input dictionary or alternatively its default