From 158cfdaf961d9f297b057c27604d59829bc14057 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 25 Feb 2022 11:39:19 +0100 Subject: [PATCH] communicate super requirements of models properly --- mlair/model_modules/abstract_model_class.py | 16 ++++++++++++++-- mlair/model_modules/branched_input_networks.py | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 4a323f46..d8e27510 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -253,5 +253,17 @@ class AbstractModelClass(ABC): def own_args(cls, *args): """Return all arguments (including kwonlyargs).""" arg_spec = inspect.getfullargspec(cls) - list_of_args = arg_spec.args + arg_spec.kwonlyargs - return remove_items(list_of_args, ["self"] + list(args)) + list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args() + return list(set(remove_items(list_of_args, ["self"] + list(args)))) + + @classmethod + def super_args(cls): + args = [] + for super_cls in cls.__mro__: + if super_cls == cls: + continue + if hasattr(super_cls, "own_args"): + # args.extend(super_cls.own_args()) + args.extend(getattr(super_cls, "own_args")()) + return list(set(args)) + diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py index 4345d48b..8da3b17e 100644 --- a/mlair/model_modules/branched_input_networks.py +++ b/mlair/model_modules/branched_input_networks.py @@ -290,4 +290,4 @@ class BranchedInputFCN(AbstractModelClass): # pragma: no cover self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae", var_loss]} # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])], - # "metrics": ["mse", "mae", var_loss]} \ No newline at end of file + # "metrics": ["mse", "mae", var_loss]} -- GitLab