diff --git a/video_prediction/layers/normalization.py b/video_prediction/layers/normalization.py
index 453391e002eed200c142eb131bc5a8f76ed5d1ab..2f2a9bb9c0cf41eb7ed10f25e606c06791e0d2a3 100644
--- a/video_prediction/layers/normalization.py
+++ b/video_prediction/layers/normalization.py
@@ -144,26 +144,40 @@ def fused_instance_norm(inputs,
         gamma = array_ops.reshape(gamma, params_shape_broadcast)
 
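+    # Implement instance norm with a single fused_batch_norm call by folding the
+    # batch dimension into the channel dimension, so that every (sample, channel)
+    # pair is normalized over its own spatial extent.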
     if data_format == DATA_FORMAT_NHWC:
-      inputs = array_ops.transpose(inputs, [0, reduction_axis] + list(range(1, reduction_axis)))
-    inputs_nchw_shape = inputs.shape
-    inputs = array_ops.reshape(inputs, [1, -1] + inputs_nchw_shape.as_list()[2:])
+      inputs = array_ops.transpose(inputs, list(range(1, reduction_axis)) + [0, reduction_axis])
+    elif data_format == DATA_FORMAT_NCHW:
+      inputs = array_ops.transpose(inputs, list(range(2, inputs_rank)) + [0, reduction_axis])
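+    # The transposes above move the batch and channel axes to the end; now collapse
+    # them into a single trailing axis: [*spatial, N, C] -> [1, *spatial, N * C].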
+    hw, n, c = inputs.shape.as_list()[:-2], inputs.shape[-2].value, inputs.shape[-1].value
+    inputs = array_ops.reshape(inputs, [1] + hw + [n * c])
     if inputs.shape.ndims != 4:
-        # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [DH, W]
+        # combine all the spatial dimensions into only two, e.g. [D, H, W] -> [D, HW]
         if inputs.shape.ndims > 4:
-            inputs_ndims4_shape = inputs.shape.as_list()[:2] + [-1, inputs_nchw_shape.as_list()[-1]]
+            inputs_ndims4_shape = [1, hw[0], -1, n * c]
         else:
-            inputs_ndims4_shape = inputs.shape.as_list()[:2] + [1, -1]
+            inputs_ndims4_shape = [1, 1, -1, n * c]
         inputs = array_ops.reshape(inputs, inputs_ndims4_shape)
-    beta = array_ops.reshape(array_ops.tile(beta[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
-    gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [inputs_nchw_shape[0].value, 1]), [-1])
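+    # Tile the per-channel beta and gamma across the batch so they line up with
+    # the folded N * C axis (entry n * C + c corresponds to channel c).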
+    beta = array_ops.reshape(array_ops.tile(beta[None, :], [n, 1]), [-1])
+    gamma = array_ops.reshape(array_ops.tile(gamma[None, :], [n, 1]), [-1])
 
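+    # With NHWC the folded N * C axis is treated as the channel axis, so the
+    # statistics are computed over the spatial axes for each (sample, channel) pair.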
     outputs, _, _ = nn.fused_batch_norm(
         inputs, gamma, beta, epsilon=epsilon,
-        data_format=DATA_FORMAT_NCHW, name='instancenorm')
+        data_format=DATA_FORMAT_NHWC, name='instancenorm')
 
-    outputs = array_ops.reshape(outputs, inputs_nchw_shape)
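+    # Split the folded axis back apart and restore the original data layout.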
+    outputs = array_ops.reshape(outputs, hw + [n, c])
     if data_format == DATA_FORMAT_NHWC:
-      outputs = array_ops.transpose(outputs, [0] + list(range(2, inputs_rank)) + [1])
+      outputs = array_ops.transpose(outputs, [inputs_rank - 2] + list(range(inputs_rank - 2)) + [inputs_rank - 1])
+    elif data_format == DATA_FORMAT_NCHW:
+      outputs = array_ops.transpose(outputs, [inputs_rank - 2, inputs_rank - 1] + list(range(inputs_rank - 2)))
 
     if activation_fn is not None:
       outputs = activation_fn(outputs)