diff --git a/mlair/model_modules/abstract_model_class.py b/mlair/model_modules/abstract_model_class.py index 4a323f46ff95a7ca66c157f2e4d6d3184f244a4a..d8e275101e7ec1a2388cc52111034d2497c1e82d 100644 --- a/mlair/model_modules/abstract_model_class.py +++ b/mlair/model_modules/abstract_model_class.py @@ -253,5 +253,17 @@ class AbstractModelClass(ABC): def own_args(cls, *args): """Return all arguments (including kwonlyargs).""" arg_spec = inspect.getfullargspec(cls) - list_of_args = arg_spec.args + arg_spec.kwonlyargs - return remove_items(list_of_args, ["self"] + list(args)) + list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args() + return list(set(remove_items(list_of_args, ["self"] + list(args)))) + + @classmethod + def super_args(cls): + args = [] + for super_cls in cls.__mro__: + if super_cls == cls: + continue + if hasattr(super_cls, "own_args"): + # args.extend(super_cls.own_args()) + args.extend(getattr(super_cls, "own_args")()) + return list(set(args)) + diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py index 4345d48b87d9304e7de49f16cb4cc39427d3bce2..8da3b17edc47fe281738bac54d26a28836716fac 100644 --- a/mlair/model_modules/branched_input_networks.py +++ b/mlair/model_modules/branched_input_networks.py @@ -290,4 +290,4 @@ class BranchedInputFCN(AbstractModelClass): # pragma: no cover self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae", var_loss]} # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss], loss_weights=[2, 1])], - # "metrics": ["mse", "mae", var_loss]} \ No newline at end of file + # "metrics": ["mse", "mae", var_loss]}