NHWC backward with explicit format conversions

Hello,

We noticed that with a model in NHWC (channels_last) memory layout and an explicit format conversion around BatchNorm2d, the backward() call produces some interesting output (measured with hooks): the gradient w.r.t. the input of the last layer of the two-layer model is NHWC, but the gradient w.r.t. the output of the first layer is NCHW.
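
For reference when reading the hook output below: the hooks print True for channels_last (NHWC) and False for contiguous (NCHW), as in this small standalone check:

import torch

t = torch.rand(2, 8, 4, 4).to(memory_format=torch.channels_last)   # NHWC
print(t.is_contiguous(memory_format=torch.channels_last))          # prints True
t = t.to(memory_format=torch.contiguous_format)                    # back to NCHW
print(t.is_contiguous(memory_format=torch.channels_last))          # prints False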

The issue does not occur if the explicit conversions are omitted and BatchNorm2d is allowed to run in NHWC (see the sketch after this paragraph). Can anyone help explain why the memory layout of the gradient tensor appears to switch in this reproducer?
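
For completeness, the variant that does not show the issue simply drops the explicit conversions, so BatchNorm2d runs directly on channels_last tensors. A minimal sketch (the class name is ours; everything else in the reproducer below stays the same):

class BnAddReluNHWC(torch.nn.BatchNorm2d):
    # same as BnAddRelu in the reproducer, but without the explicit format conversions
    def __init__(self, planes, fuse_relu=False, bn_group=1):
        super().__init__(planes)
        self.fuse_relu_flag = fuse_relu

    def forward(self, x, z=None):
        out = super().forward(x)
        if z is not None:
            out = out.add_(z)
        if self.fuse_relu_flag:
            out = out.relu_()
        return out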

Reproducer

import torch

def is_channels_last(tensor):
    # True if the tensor is channels_last (NHWC), False if not, None if there is no tensor
    if tensor is not None:
        return tensor.is_contiguous(memory_format=torch.channels_last)
    else:
        return None

def check(tensor):
    # render is_channels_last() for a single tensor or for each element of a tuple
    if isinstance(tensor, tuple):
        res = ""
        for i, t in enumerate(tensor):
            res += "{}".format(is_channels_last(t))
            if i < len(tensor) - 1:
                res += ", "
    elif tensor is None:
        res = None
    else:
        res = is_channels_last(tensor)
    return res

def fw_hook(module, input, output):
    if isinstance(module, torch.nn.Conv2d):
        print("Debug: forward module {}, input {}, output {}, weight {}, bias {}".format(
            module, check(input), check(output), check(module.weight), check(module.bias)))
    else:
        print("Debug: forward module {}, input {}, output {}".format(
            module, check(input), check(output)))

def bw_hook(module, grad_wrt_input, grad_wrt_output):
    if isinstance(module, torch.nn.Conv2d):
        print("Debug: backward module {}, grad_wrt_input {}, grad_wrt_output {}, weight {}, bias {}".format(
            module, check(grad_wrt_input), check(grad_wrt_output), check(module.weight), check(module.bias)))
    else:
        print("Debug: backward module {}, grad_wrt_input {}, grad_wrt_output {}".format(
            module, check(grad_wrt_input), check(grad_wrt_output)))

class BnAddRelu(torch.nn.BatchNorm2d):
    def __init__(self, planes, fuse_relu=False, bn_group=1):
        super(BnAddRelu, self).__init__(planes)

        # bn_group is unused in this reproducer
        self.fuse_relu_flag = fuse_relu

    def forward(self, x, z=None):
        # explicit conversion to NCHW (contiguous) before running BatchNorm2d ...
        x = x.to(memory_format=torch.contiguous_format)
        out = super().forward(x)
        if z is not None:
            z = z.to(memory_format=torch.contiguous_format)
            out = out.add_(z)
        if self.fuse_relu_flag:
            out = out.relu_()
        # ... and explicit conversion back to NHWC (channels_last) afterwards
        out = out.to(memory_format=torch.channels_last)
        return out

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.conv = torch.nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn = BnAddRelu(128)

    def forward(self, input):
        out = self.conv(input)
        out = self.bn(out)
        return out

model = Model().to(memory_format=torch.channels_last).cuda()

for _, l in model._modules.items():
    l.register_forward_hook(fw_hook)
    l.register_backward_hook(bw_hook)  # non-full backward hook; see the deprecation warning in the output below

input = torch.rand(120, 128, 38, 38, dtype=torch.float, device="cuda").to(memory_format=torch.channels_last)
output = model(input)
output.backward(torch.rand_like(output))

Output

$ python test_conv_bw.py
Debug: forward module Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), input True, output True, weight True, bias None
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
Debug: forward module BnAddRelu(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), input True, output True
Debug: backward module BnAddRelu(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), grad_wrt_input **True**, grad_wrt_output True
Debug: backward module Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), grad_wrt_input None, True, grad_wrt_output **False**, weight True, bias None
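
To narrow this down, here is a minimal sketch (not part of the log above) that isolates just the explicit conversion node under autograd. The open question for us is whether the backward of .to(memory_format=...) is expected to hand upstream layers the gradient in the layout it receives (NCHW here) rather than converting it back to channels_last:

import torch

x = torch.rand(120, 128, 38, 38, device="cuda").to(memory_format=torch.channels_last)
x.requires_grad_()
y = x.to(memory_format=torch.contiguous_format)  # the same explicit conversion used in BnAddRelu
y.backward(torch.rand_like(y))                   # rand_like preserves y's layout, so grad_output is NCHW
print(x.grad.is_contiguous(memory_format=torch.channels_last))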