Issue with BatchNorm + Conv that depends on dtype

Hello,

During my experiments I came across the following issue.

I have a custom convolution class that is based on the torch.nn.Conv2d implementation:

import torch

class FastConv(object):

  @staticmethod
  def forward(x, w, b, conv_param):
    # print(x.shape, w.shape, b.shape, conv_param)
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    stride, pad = conv_param['stride'], conv_param['pad']
    # build a Conv2d module and plug in the externally supplied weights and bias
    layer = torch.nn.Conv2d(C, F, (HH, WW), stride=stride, padding=pad)
    layer.weight = torch.nn.Parameter(w)
    layer.bias = torch.nn.Parameter(b)
    # detach the input so the autograd graph for the custom backward starts at tx
    tx = x.detach()
    tx.requires_grad = True
    out = layer(tx)
    cache = (x, w, b, conv_param, tx, out, layer)
    return out, cache

  @staticmethod
  def backward(dout, cache):
    try:
      x, _, _, _, tx, out, layer = cache
      out.backward(dout)
      dx = tx.grad.detach()
      dw = layer.weight.grad.detach()
      db = layer.bias.grad.detach()
      layer.weight.grad = layer.bias.grad = None
    except RuntimeError:
      dx, dw, db = torch.zeros_like(tx), torch.zeros_like(layer.weight), torch.zeros_like(layer.bias)
    return dx, dw, db

The purpose of this FastConv class is to be able to explicitly pass weights and biases on each forward call and to store the intermediate tensors needed for a custom backward pass (for pedagogical purposes).
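
For illustration, a minimal usage sketch of how the class is driven (toy shapes of my own choosing, device picked automatically):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(2, 3, 8, 8, device=device)
w = torch.randn(4, 3, 3, 3, device=device)
b = torch.zeros(4, device=device)
conv_param = {'stride': 1, 'pad': 1}

out, cache = FastConv.forward(x, w, b, conv_param)  # out: (2, 4, 8, 8)
dout = torch.randn_like(out)                        # upstream gradient
dx, dw, db = FastConv.backward(dout, cache)         # grads w.r.t. input, weight, bias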

Alongside it there is code that performs batch normalization in two ways: one applies BatchNorm2d directly, and the other simulates it by reshaping the input tensor and using BatchNorm1d (a standalone sketch of that equivalence follows the snippet below).

device = 'cuda'
num_inputs = 2
input_dims = (3, 16, 16)
next_filt = 16

batchnorm = True
dtype = torch.float32

kernel_size = 3
bn_param = {'mode': 'train'}
# stride and padding preserve output spatial size
conv_param = {'stride': 1, 'pad': (kernel_size - 1) // 2}


x = torch.randn(num_inputs, *input_dims, dtype=dtype, device=device)

gamma = torch.ones(input_dims[0], device=device, dtype=dtype)
beta = torch.zeros(input_dims[0], device=device, dtype=dtype)

Weight = torch.randn(next_filt, input_dims[0], kernel_size, kernel_size, dtype=dtype, device=device)
b = torch.zeros(next_filt, dtype=dtype, device=device)

N, C, H, W = x.shape

## PyTorch BN2d
try:
  out = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(x)

  out_bn_2d, _ = FastConv.forward(out, Weight, b, conv_param)
except Exception as e:
  print(e)

## PyTorch BN1d
try:
  ch_view = x.transpose(1,2).transpose(2,3).reshape(N * H * W, C) 
  out = torch.nn.BatchNorm1d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(ch_view)
  out = out.reshape(N, H, W, C).transpose(2,3).transpose(1,2)

  out_bn_1d, _ = FastConv.forward(out, Weight, b, conv_param)
except Exception as e:
  print(e)
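
As a side note on that equivalence, the reshape trick can be checked in isolation; a minimal sketch (my own, using fresh default-initialized BN modules in train mode) looks like this:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(2, 3, 16, 16, device=device)
N, C, H, W = x.shape

ref = torch.nn.BatchNorm2d(C, device=device)(x)                    # per-channel stats over N, H, W
flat = x.permute(0, 2, 3, 1).reshape(N * H * W, C)                 # NCHW -> (N*H*W, C)
sim = torch.nn.BatchNorm1d(C, device=device)(flat).reshape(N, H, W, C).permute(0, 3, 1, 2)

print(torch.allclose(ref, sim, atol=1e-5))                         # expected: True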

The thing is that with dtype=torch.float64 everything in this setup works fine, but if we change to dtype=torch.float32, only the BatchNorm2d implementation works, while the other one fails with:

set_sizes_and_strides is not allowed on a Tensor created from .data or .detach().
If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
For example, change:
    x.data.set_(y)
to:
    with torch.no_grad():
        x.set_(y)
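
The .detach() in this setup happens inside FastConv.forward, so one thing that could be experimented with (a hedged sketch only, not verified against this exact failure) is to detach and clone the input, so the conv input no longer shares storage and metadata with the original tensor:

# hypothetical replacement for the input handling in FastConv.forward
tx = x.detach().clone()     # clone gives tx its own storage and metadata
tx.requires_grad_(True)
out = layer(tx)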

The reproduction code is in a Colab notebook; change the dtype in the third cell to see the difference.

[UPD] The issue does not occur with dtype==torch.float64 on CUDA, and it also does not occur with float32 on CPU.

Please help with this strange behaviour.

Hello @ptrblck, could you please have a look?

Both approaches work in my setup:

class FastConv(object):
  @staticmethod
  def forward(x, w, b, conv_param):
    # print(x.shape, w.shape, b.shape, conv_param)
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    stride, pad = conv_param['stride'], conv_param['pad']
    layer = torch.nn.Conv2d(C, F, (HH, WW), stride=stride, padding=pad)
    layer.weight = torch.nn.Parameter(w)
    layer.bias = torch.nn.Parameter(b)
    tx = x.detach()
    tx.requires_grad = True
    out = layer(tx)
    cache = (x, w, b, conv_param, tx, out, layer)
    return out, cache

  @staticmethod
  def backward(dout, cache):
    try:
      x, _, _, _, tx, out, layer = cache
      out.backward(dout)
      dx = tx.grad.detach()
      dw = layer.weight.grad.detach()
      db = layer.bias.grad.detach()
      layer.weight.grad = layer.bias.grad = None
    except RuntimeError:
      dx, dw, db = torch.zeros_like(tx), torch.zeros_like(layer.weight), torch.zeros_like(layer.bias)
    return dx, dw, db


device = 'cuda'
num_inputs = 2
input_dims = (3, 16, 16)
next_filt = 16

batchnorm = True
dtype = torch.float32

kernel_size = 3
bn_param = {'mode': 'train'}
# stride and padding preserve output spatial size
conv_param = {'stride': 1, 'pad': (kernel_size - 1) // 2}

x = torch.randn(num_inputs, *input_dims, dtype=dtype, device=device)

gamma = torch.ones(input_dims[0], device=device, dtype=dtype)
beta = torch.zeros(input_dims[0], device=device, dtype=dtype)

Weight = torch.randn(next_filt, input_dims[0], kernel_size, kernel_size, dtype=dtype, device=device)
b = torch.zeros(next_filt, dtype=dtype, device=device)

N, C, H, W = x.shape

## PyTorch BN2d
out = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(x)
out_bn_2d, _ = FastConv.forward(out, Weight, b, conv_param)
print(out_bn_2d.shape)
# > torch.Size([2, 16, 16, 16])

## PyTorch BN1d
ch_view = x.transpose(1,2).transpose(2,3).reshape(N * H * W, C) 
out = torch.nn.BatchNorm1d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(ch_view)
out = out.reshape(N, H, W, C).transpose(2,3).transpose(1,2)
out_bn_1d, _ = FastConv.forward(out, Weight, b, conv_param)
print(out_bn_1d.shape)
# > torch.Size([2, 16, 16, 16])
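
As an extra sanity check on top of the shapes, the two outputs should also agree numerically up to float32 rounding, since both paths use the same weights and freshly initialized BN modules:

print(torch.allclose(out_bn_2d, out_bn_1d, atol=1e-4))
# > True (expected)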

If you are using an older PyTorch release, could you update to the latest stable or nightly?
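
A quick way to check which build is active in the notebook:

import torch

print(torch.__version__)    # e.g. 1.10.0+cu111
print(torch.version.cuda)   # CUDA version the wheel was built against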

Hello again,

Yes, this solves the issue.

I changed Colab's default version 1.10.0+cu111 to 1.12.0.dev20220209+cu111 and obtained the proper behavior.

Thank you