Hello,
We noticed that with a model in NHWC (channels_last) memory layout and an explicit layout conversion around BatchNorm2d, the backward() call produces some unexpected output (measured with hooks). The gradient with respect to the input of the last layer of the two-layer model is NHWC, but the gradient with respect to the output of the first layer is NCHW.
The issue does not occur if the explicit conversions are omitted and BatchNorm2d is allowed to run in NHWC. Can anyone help explain why the gradient tensor's memory layout appears to be switched in this reproducer?
Reproducer
import torch
def is_channels_last(tensor):
    """Report whether *tensor* is contiguous in channels_last memory format.

    Returns ``None`` when *tensor* is ``None``; otherwise returns the result
    of ``tensor.is_contiguous(memory_format=torch.channels_last)``.
    """
    if tensor is None:
        return None
    return tensor.is_contiguous(memory_format=torch.channels_last)
def check(tensor):
    """Summarise the channels_last status of a tensor or a tuple of tensors.

    * ``None``  -> ``None``
    * a tuple   -> a string with the ``is_channels_last`` result of each
      element, separated by ``", "`` (empty string for an empty tuple)
    * otherwise -> the ``is_channels_last`` result for the single value
    """
    if tensor is None:
        return None
    if isinstance(tensor, tuple):
        # One entry per element; None elements are rendered as "None".
        return ", ".join("{}".format(is_channels_last(item)) for item in tensor)
    return is_channels_last(tensor)
def fw_hook(module, input, output):
    """Forward hook: print the channels_last status of a module's tensors.

    For ``torch.nn.Conv2d`` modules the status of ``module.weight`` and
    ``module.bias`` is printed as well. Returns ``None``.
    """
    if not isinstance(module, torch.nn.Conv2d):
        print("Debug: forward module {}, input {}, output {}".format(
            module, check(input), check(output)))
        return
    # Convolutions additionally report the layout of their parameters.
    print("Debug: forward module {}, input {}, output {}, weight {}, bias {}".format(
        module, check(input), check(output), check(module.weight), check(module.bias)))
def bw_hook(module, grad_wrt_input, grad_wrt_output):
    """Backward hook: print the channels_last status of the gradients.

    ``grad_wrt_input`` and ``grad_wrt_output`` are the second and third
    arguments the hook is called with. For ``torch.nn.Conv2d`` modules the
    status of ``module.weight`` and ``module.bias`` is printed as well.
    Returns ``None``.
    """
    if not isinstance(module, torch.nn.Conv2d):
        print("Debug: backward module {}, grad_wrt_input {}, grad_wrt_output {}".format(
            module, check(grad_wrt_input), check(grad_wrt_output)))
        return
    # Convolutions additionally report the layout of their parameters.
    print("Debug: backward module {}, grad_wrt_input {}, grad_wrt_output {}, weight {}, bias {}".format(
        module, check(grad_wrt_input), check(grad_wrt_output), check(module.weight), check(module.bias)))
class BnAddRelu(torch.nn.BatchNorm2d):
    """BatchNorm2d with an optional residual add and optional ReLU.

    The batch norm itself is run on a tensor converted to
    ``torch.contiguous_format``; the result is converted to
    ``torch.channels_last`` before it is returned.
    """

    def __init__(self, planes, fuse_relu=False, bn_group=1):
        """Create the layer.

        Args:
            planes: forwarded to ``BatchNorm2d.__init__``.
            fuse_relu: when true, ``forward`` applies an in-place ReLU.
            bn_group: accepted but not used by this class.
        """
        super().__init__(planes)
        self.fuse_relu_flag = fuse_relu

    def forward(self, x, z=None):
        """Return ``bn(x)``, plus ``z`` if given, through ReLU if enabled."""
        # Explicit layout conversion in front of the batch norm.
        x_contig = x.to(memory_format=torch.contiguous_format)
        out = super().forward(x_contig)

        if z is not None:
            # The addend is converted the same way, then added in place.
            z_contig = z.to(memory_format=torch.contiguous_format)
            out.add_(z_contig)

        if self.fuse_relu_flag:
            out.relu_()

        # Explicit layout conversion back to channels_last for the caller.
        return out.to(memory_format=torch.channels_last)
class Model(torch.nn.Module):
    """Two-layer model: a 3x3 Conv2d (128 -> 128, no bias) followed by BnAddRelu."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            128,
            128,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )
        self.bn = BnAddRelu(128)

    def forward(self, input):
        """Apply the convolution, then the batch-norm layer."""
        conv_out = self.conv(input)
        return self.bn(conv_out)
# Build the model, convert it to channels_last and move it to the GPU.
model = Model()
model = model.to(memory_format=torch.channels_last)
model = model.cuda()

# Attach the layout-printing hooks to every direct sub-module.
# NOTE: register_backward_hook is the non-full variant; the recorded output
# shows PyTorch's deprecation warning for it.
for layer in model._modules.values():
    layer.register_forward_hook(fw_hook)
    layer.register_backward_hook(bw_hook)

# Random channels_last input on the GPU.
input = torch.rand(120, 128, 38, 38, dtype=torch.float, device="cuda")
input = input.to(memory_format=torch.channels_last)

# One forward pass and one backward pass with a random upstream gradient.
output = model(input)
upstream_grad = torch.rand_like(output)
output.backward(upstream_grad)
Output
$ python test_conv_bw.py
Debug: forward module Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), input True, output True, weight True, bias None
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
Debug: forward module BnAddRelu(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), input True, output True
Debug: backward module BnAddRelu(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), grad_wrt_input **True**, grad_wrt_output True
Debug: backward module Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), grad_wrt_input None, True, grad_wrt_output **False**, weight True, bias None