From 158cfdaf961d9f297b057c27604d59829bc14057 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 25 Feb 2022 11:39:19 +0100
Subject: [PATCH] communicate super requirements of models properly

---
 mlair/model_modules/abstract_model_class.py    | 16 ++++++++++++++--
 mlair/model_modules/branched_input_networks.py |  2 +-
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py
index 4a323f46..d8e27510 100644
--- a/mlair/model_modules/abstract_model_class.py
+++ b/mlair/model_modules/abstract_model_class.py
@@ -253,5 +253,17 @@ class AbstractModelClass(ABC):
     def own_args(cls, *args):
         """Return all arguments (including kwonlyargs)."""
         arg_spec = inspect.getfullargspec(cls)
-        list_of_args = arg_spec.args + arg_spec.kwonlyargs
-        return remove_items(list_of_args, ["self"] + list(args))
+        list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args()
+        return list(set(remove_items(list_of_args, ["self"] + list(args))))
+
+    @classmethod
+    def super_args(cls):
+        args = []
+        for super_cls in cls.__mro__:
+            if super_cls == cls:
+                continue
+            if hasattr(super_cls, "own_args"):
+                # args.extend(super_cls.own_args())
+                args.extend(getattr(super_cls, "own_args")())
+        return list(set(args))
+
diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py
index 4345d48b..8da3b17e 100644
--- a/mlair/model_modules/branched_input_networks.py
+++ b/mlair/model_modules/branched_input_networks.py
@@ -290,4 +290,4 @@ class BranchedInputFCN(AbstractModelClass):  # pragma: no cover
         self.compile_options = {"loss": [keras.losses.mean_squared_error],
                                 "metrics": ["mse", "mae", var_loss]}
         # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])],
-        #                         "metrics": ["mse", "mae", var_loss]}
\ No newline at end of file
+        #                         "metrics": ["mse", "mae", var_loss]}
-- 
GitLab