diff --git a/src/model_modules/advanced_paddings.py b/src/model_modules/advanced_paddings.py index c24085ffae1939b8bc95eae0e500ef2d07f917ad..2e2892d8b67e999f3d556403883c905d74a86392 100644 --- a/src/model_modules/advanced_paddings.py +++ b/src/model_modules/advanced_paddings.py @@ -118,9 +118,9 @@ class Padding2D: def __init__(self, padding_type): self.padding_type = padding_type self.allowed_paddings = { - 'RefPad2D': ReflectionPadding2D, 'ReflectionPadding2D': ReflectionPadding2D, - 'SymPad2D': SymmetricPadding2D, 'SymmetricPadding2D': SymmetricPadding2D, - 'ZeroPad2D': ZeroPadding2D, 'ZeroPadding2D': ZeroPadding2D + **dict.fromkeys(("RefPad2D", "ReflectionPadding2D"), ReflectionPadding2D), + **dict.fromkeys(("SymPad2D", "SymmetricPadding2D"), SymmetricPadding2D), + **dict.fromkeys(("ZeroPad2D", "ZeroPadding2D"), ZeroPadding2D) } def _check_and_get_padding(self): @@ -295,8 +295,8 @@ if __name__ == '__main__': kernel_1 = (3, 3) kernel_2 = (5, 5) - kernel_3 = (3,3) - x = np.array(range(2000)).reshape(-1, 10, 10, 1) + kernel_3 = (3, 3) + x = np.array(range(2000)).reshape((-1, 10, 10, 1)) y = x.mean(axis=(1, 2)) x_input = Input(shape=x.shape[1:])