diff --git a/video_prediction/layers/normalization.py b/video_prediction/layers/normalization.py
index 453391e002eed200c142eb131bc5a8f76ed5d1ab..2f2a9bb9c0cf41eb7ed10f25e606c06791e0d2a3 100644
--- a/video_prediction/layers/normalization.py
+++ b/video_prediction/layers/normalization.py
@@ -144,26 +144,52 @@ def fused_instance_norm(inputs,
             gamma = array_ops.reshape(gamma, params_shape_broadcast)
 
         if data_format == DATA_FORMAT_NHWC:
-            inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis)))
-        inputs_nchw_shape = inputs.shape
-        inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:])
+            inputs = array_ops.transpose(inputs, list(range(1, reduction_axis)) + [0, reduction_axis])
+        if data_format == DATA_FORMAT_NCHW:
+            inputs = array_ops.transpose(inputs, list(range(2, inputs_rank)) + [0, reduction_axis])
+        hw, n, c = inputs.shape.as_list()[:-2], inputs.shape[-2].value, inputs.shape[-1].value
+        inputs = array_ops.reshape(inputs, [1] + hw + [n * c])
         if inputs.shape.ndims != 4:
             # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W]
             if inputs.shape.ndims > 4:
-                inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]]
+                inputs_ndims4_shape = [1, hw[0], -1, n * c]
             else:
-                inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1]
+                inputs_ndims4_shape = [1, 1, -1, n * c]
             inputs = array_ops.reshape(inputs, inputs_ndims4_shape)
-        beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
-        gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
+        beta = array_ops.reshape(array_ops.tile(beta[None, :], [n, 1]), [-1])
+        gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [n, 1]), [-1])
 
         outputs, _, _ = nn.fused_batch_norm(
             inputs, gamma, beta, epsilon=epsilon,
-            data_format=DATA_FORMAT_NCHW, name='instancenorm')
+            data_format=DATA_FORMAT_NHWC, name='instancenorm')
 
-        outputs = array_ops.reshape(outputs, inputs_nchw_shape)
+        outputs = array_ops.reshape(outputs, hw + [n, c])
         if data_format == DATA_FORMAT_NHWC:
-            outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1])
+            outputs = array_ops.transpose(outputs, [inputs_rank - 2] + list(range(inputs_rank - 2)) + [inputs_rank - 1])
+        if data_format == DATA_FORMAT_NCHW:
+            outputs = array_ops.transpose(outputs, [inputs_rank - 2, inputs_rank - 1] + list(range(inputs_rank - 2)))
+
+        # if data_format == DATA_FORMAT_NHWC:
+        #     inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis)))
+        # inputs_nchw_shape = inputs.shape
+        # inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:])
+        # if inputs.shape.ndims != 4:
+        #     # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W]
+        #     if inputs.shape.ndims > 4:
+        #         inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]]
+        #     else:
+        #         inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1]
+        #     inputs = array_ops.reshape(inputs, inputs_ndims4_shape)
+        # beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
+        # gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
+        #
+        # outputs, _, _ = nn.fused_batch_norm(
+        #     inputs, gamma, beta, epsilon=epsilon,
+        #     data_format=DATA_FORMAT_NCHW, name='instancenorm')
+        #
+        # outputs = array_ops.reshape(outputs, inputs_nchw_shape)
+        # if data_format == DATA_FORMAT_NHWC:
+        #     outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1])
 
         if activation_fn is not None:
             outputs = activation_fn(outputs)
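
Not part of the patch itself: the sketch below illustrates, with NumPy only, the equivalence the new code relies on. Folding the batch axis into the channel axis ([N, H, W, C] -> [1, H, W, N*C]) turns per-(sample, channel) instance normalization into ordinary per-channel batch normalization over the spatial dimensions, which is what nn.fused_batch_norm computes. The helper names and the fixed 4-D NHWC shape are illustrative assumptions, not code from the repository.

# Minimal sketch (assumed, not repository code): verify that folding N into the
# channel axis makes per-channel batch normalization equal instance normalization.
import numpy as np

def instance_norm_reference(x, eps=1e-6):
    # x: [N, H, W, C]; normalize each (n, c) slice over its spatial dimensions.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def instance_norm_via_folded_batch_norm(x, eps=1e-6):
    # Same statistics, computed the way the patch arranges them: move N next to C,
    # fold them into one "channel" axis of size N*C, and normalize each folded
    # channel over the spatial dimensions (what a fused batch norm would do).
    n, h, w, c = x.shape
    folded = x.transpose(1, 2, 0, 3).reshape(1, h, w, n * c)   # [1, H, W, N*C]
    mean = folded.mean(axis=(0, 1, 2), keepdims=True)
    var = folded.var(axis=(0, 1, 2), keepdims=True)
    normed = (folded - mean) / np.sqrt(var + eps)
    # Undo the folding: [1, H, W, N*C] -> [H, W, N, C] -> [N, H, W, C]
    return normed.reshape(h, w, n, c).transpose(2, 0, 1, 3)

x = np.random.randn(2, 5, 7, 3).astype(np.float64)
assert np.allclose(instance_norm_reference(x), instance_norm_via_folded_batch_norm(x))

The same folding explains why the patch tiles beta and gamma n times: the folded tensor has n * c channels, and each of them needs its own offset and scale.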