diff --git a/mlair/model_modules/residual_networks.py b/mlair/model_modules/residual_networks.py index 8d39d776238341fbf2da29673910a2127d15528b..a9b502c4ef9ba5daa2b624f678b1f951dad3b747 100644 --- a/mlair/model_modules/residual_networks.py +++ b/mlair/model_modules/residual_networks.py @@ -47,7 +47,10 @@ class BranchedInputResNet(BranchedInputCNN): layer_name = layer_kwargs.pop("name").split("_") layer_name = "_".join([*layer_name[0:2], "%s", *layer_name[2:]]) act = layer_kwargs.pop("activation") - act_name = act.args[0] if isinstance(act, partial) else act.__name__ + if isinstance(act, partial): + act_name = act.args[0] if act.func.__name__ == "Activation" else act.func.__name__ + else: + act_name = act.__name__ use_1x1conv = layer_kwargs.pop("use_1x1conv", False) def block(x):