From 2bf381a444e9b481d83eb6a408be9d0b7f880982 Mon Sep 17 00:00:00 2001
From: Michael <m.langguth@fz-juelich.de>
Date: Mon, 21 Jun 2021 11:38:37 +0200
Subject: [PATCH] Revised reduce_dict-function in general_utils.py.

---
 video_prediction_tools/utils/general_utils.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py
index 1b6b6b31..9ab152ba 100644
--- a/video_prediction_tools/utils/general_utils.py
+++ b/video_prediction_tools/utils/general_utils.py
@@ -155,10 +155,11 @@ def check_dir(path2dir: str, lcreate=False):
 
 def reduce_dict(dict_in: dict, dict_ref: dict):
     """
-    Returns reduced version of input directory with keys only that are also part in reference dictionary
+    Reduces input dictionary to keys from reference dictionary. If the input dictionary lacks some keys, these are 
+    copied over from the reference dictionary, i.e. the reference dictionary provides the defaults
     :param dict_in: input dictionary
     :param dict_ref: reference dictionary
-    :return: subset of input dictionary
+    :return: reduced form of input dictionary (with keys complemented from dict_ref if necessary)
     """
     method = reduce_dict.__name__
 
@@ -167,15 +168,11 @@ def reduce_dict(dict_in: dict, dict_ref: dict):
                                       .format(method, type(dict_in))
     assert isinstance(dict_ref, dict), "%{0}: dict_ref must be a dictionary, but is of type {1}"\
                                        .format(method, type(dict_ref))
+    
+    dict_merged = {**dict_ref, **dict_in}
+    dict_reduced = {key: dict_merged[key] for key in dict_ref}
 
-    if set(dict_ref.keys()).issubset(set(dict_in.keys())):
-        dict_in_subset = {key: dict_in[key] for key in dict_ref}
-    else:
-        print("Keys in dict_ref: {0}".format(", ".join(dict_ref.keys())))
-        print("Keys in dict_in: {0}".format(", ".join(dict_in.keys())))
-        raise KeyError("%{0}: Could not find all required keys from dict_ref in dict_in.".format(method))
-
-    return dict_in_subset
+    return dict_reduced
 
 
 def provide_default(dict_in, keyname, default=None, required=False):
-- 
GitLab