Batchnorm channels last

Hi,

I was just implementing a simple 2D batchnorm and wanted to use the channels-last memory format. When I use the code pasted below, my GPU profiler (NSight) shows the forward kernels using the channels-last format, as indicated by their names. But the backward pass uses the plain batch_norm_backward_reduce_kernel instead of the expected batch_norm_backward_reduce_channels_last_kernel. Is there something wrong with my initialization, or is this expected behavior?

import argparse

import torch
import torch.nn as nn
import torch.cuda.profiler as profiler
import pyprof

pyprof.init()

parser = argparse.ArgumentParser(description='Profile a batch normalization layer.')
parser.add_argument('n', type=int)
parser.add_argument('c', type=int)
parser.add_argument('h', type=int)
parser.add_argument('w', type=int)

args = parser.parse_args()

# Create a half-precision input on the GPU and convert it to channels-last
# before marking it as requiring gradients, so it stays a leaf tensor.
input_tensor = torch.rand(args.n, args.c, args.h, args.w).half().cuda() - 0.5
input_tensor = input_tensor.to(memory_format=torch.channels_last)
input_tensor.requires_grad_()

batchnorm = nn.BatchNorm2d(args.c).half().cuda()
loss_fxn = nn.MSELoss()
target_tensor = torch.rand(input_tensor.shape).half().cuda()
target_tensor = target_tensor.to(memory_format=torch.channels_last)

print("Profiling input tensor:", input_tensor.shape)
print("On batchnorm layer:", batchnorm)

for i in range(25):
    output = batchnorm(input_tensor).to(memory_format=torch.channels_last)
    loss = loss_fxn(output, target_tensor)
    loss.backward()

I guess at::native::batch_norm_backward_reduce_kernel is used because you are reducing the loss, so the gradient is not passed back as a 4D channels-last tensor.
If you use output.backward(torch.ones_like(output)) you would get:

at::native::batch_norm_backward_reduce_channels_last_kernel
at::native::batch_norm_backward_elemt_channels_last_kernel
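
For reference, a minimal, self-contained sketch of this suggestion (the shapes and layer size are arbitrary): backward is started from the 4D output with an explicit gradient, and torch.ones_like preserves the channels-last layout by default, so the *_channels_last backward kernels are used:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64).half().cuda()
x = torch.rand(8, 64, 32, 32).half().cuda().to(memory_format=torch.channels_last)
x.requires_grad_()

output = bn(x)
# ones_like defaults to memory_format=torch.preserve_format, so the
# gradient handed to autograd is itself a 4D channels-last tensor.
output.backward(torch.ones_like(output))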

@ptrblck: Sorry, it's not clear to me why transforming the output using ones_like helps in this case. Won't that convert the tensor into ones instead of retaining the original output?

Sorry, my bad. I see what you were saying now: we need to explicitly pass an upstream gradient to backward, which can be initialized to ones. Thanks.

@ptrblck: At the same time, we normally call backward on the loss instead of the output, correct? Does that mean that in a real network, for example, the backward pass of batchnorm will not use the channels-last format?

Yes, that’s right.

No. If you add another layer (e.g. nn.Conv2d) after the batchnorm, you would see the channels-last kernels again. The issue in your code is the scalar reduction, which drops the memory layout, if I'm not mistaken.
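
To illustrate the point, here is a minimal sketch (layer sizes are arbitrary): with an nn.Conv2d following the batchnorm, the gradient flowing back into the batchnorm is produced by the conv's backward as a 4D channels-last tensor, so the channels-last batchnorm kernels should appear even though backward is called on a scalar loss:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64).half().cuda()
conv = nn.Conv2d(64, 64, kernel_size=3, padding=1).half().cuda()
conv = conv.to(memory_format=torch.channels_last)

x = torch.rand(8, 64, 32, 32).half().cuda().to(memory_format=torch.channels_last)

out = conv(bn(x))
# The scalar reduction now happens after the conv, so the layout is only
# dropped at the loss boundary, not at the batchnorm boundary.
loss = out.mean()
loss.backward()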
