BatchNorm2d leading to large output differences in eval mode depending on batch size

I have a simple ResNet implementation composed of BatchNorm2d, Conv2d, ReLU, and Linear layers. I trained it on some data, saved it to disk via torch.jit.trace(), reloaded it into memory via torch.jit.load(), put it in eval mode, and passed in an all-zeros input tensor whose first dimension is batch_size. My expectation is that the first row of the output should not vary significantly as I vary batch_size. Instead, I see very large differences (>1e-3) depending on batch_size.

To reproduce:

git clone https://github.com/shindavid/pytorch-issue.git
cd pytorch-issue
python demo.py model.pt

Here are the contents of the demo.py script:

"""
python demo.py model.pt
"""
import random
import sys

import numpy as np
import torch

# Seed everything and force deterministic algorithms for reproducibility.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.set_printoptions(linewidth=200)
torch.use_deterministic_algorithms(True)

filename = sys.argv[1]
print('Testing: ' + filename)
# Load the traced model, move it to the GPU, and switch to eval mode.
net = torch.jit.load(filename)
net.to('cuda')
net.eval()
torch.set_grad_enabled(False)


def get_output(batch_size):
    # Feed an all-zeros batch through the network and return the first row
    # of the first output tensor, moved back to the CPU for comparison.
    input_tensor = torch.zeros((batch_size, 2, 7, 6)).to('cuda', non_blocking=True)
    output_tuple = net(input_tensor)
    output_tensor = output_tuple[0]
    return output_tensor[:1].to('cpu')


out1 = get_output(1)
failed = False
for b in range(2, 64):
    out = get_output(b)
    if not torch.all(out == out1):
        failed = True
        print('Batch size {} is NOT OK. Diffs: {}'.format(b, out - out1))


if not failed:
    print('All ok!')

The output includes lines like this, demonstrating differences that exceed 1e-3:

Batch size 63 is NOT OK. Diffs: tensor([[ 0.0014, -0.0002,  0.0013, -0.0009,  0.0013,  0.0002,  0.0004]])

I can provide more detail on the model architecture if needed, although anyone can clone the above repo and inspect the model directly. One observation I made: if I train an alternative model with all nn.BatchNorm2d layers removed, I no longer observe this batch-size-dependent output behavior, hence the title of this post.

Other observations:

  • If I keep all values on CPU, the differences become smaller (from 1e-3 to ~1e-7). A sketch of the CPU variant follows this list.
  • Loading the same model in an equivalent C++ program produces the same output values.
  • The values in net.parameters() appear never to change as a result of any get_output() call. Specifically, the signature of net never changes, where the signature is defined by:
def signature(net):
    return tuple(tuple(map(float, p.flatten())) for p in net.parameters())
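Here is the CPU variant of the check mentioned above (a minimal sketch, assuming the same demo.py setup; only the device handling changes):

net_cpu = torch.jit.load(filename)
net_cpu.eval()

def get_output_cpu(batch_size):
    # Same all-zeros probe as get_output(), but everything stays on the CPU.
    input_tensor = torch.zeros((batch_size, 2, 7, 6))
    return net_cpu(input_tensor)[0][:1]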

I am using PyTorch 1.12.1 and CUDA 11.6.

I found some seemingly related posts ([1], [2], [3]), but none seem to explain what I am observing.

This is causing major problems in my research project, so any solutions would be much appreciated!


It would be helpful if the repro could be isolated to as few layers as possible that show numerical differences.
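For example, something along these lines would isolate a single Conv2d + BatchNorm2d pair (a minimal sketch, not code from the repo):

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
).to('cuda').eval()

with torch.no_grad():
    out1 = net(torch.zeros(1, 2, 7, 6, device='cuda'))
    out_b = net(torch.zeros(64, 2, 7, 6, device='cuda'))[:1]
print((out_b - out1).abs().max())  # nonzero indicates batch-size dependence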

Could you also share more details about the setup such as the GPU being used (e.g., is it compute capability >= 8.0 which could be using TF32 for convolutions)?

It would be helpful if the repro could be isolated to as few layers as possible that show numerical differences.

I added a new file to the above git repo, called full_demo.py, which builds the network from scratch rather than loading a jit-compiled model from disk. In it, you can see the exact architecture of the model. There is no actual training, only forward passes on random input tensors in train mode to populate the batch-norm running statistics.
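The initialization amounts to a few train-mode forward passes, roughly like this (a sketch; the exact code is in full_demo.py, and Net is the module defined below):

# Populate the BatchNorm running statistics without any actual training.
net = Net(n_res_blocks=19).to('cuda')
net.train()
with torch.no_grad():
    for _ in range(10):  # the number of warm-up batches is arbitrary here
        net(torch.randn(64, 2, 7, 6, device='cuda'))
net.eval()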

As a bonus, full_demo.py rules out jit as a culprit.

You can specify the number of residual block layers to add to the network by command line:

python full_demo.py 0   # 0 residual blocks, this is as few layers as possible
python full_demo.py 19  # this matches the size of model.pt

When you pass in 19, you get differences of the same order of magnitude (1e-3). When you pass in 0, you get much smaller differences (1e-8).

For completeness, here is the definition of the Net module that represents the resnet:

import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    def __init__(self, n_input_channels: int, n_conv_filters: int):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(n_input_channels, n_conv_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.batch = nn.BatchNorm2d(n_conv_filters)

    def forward(self, x):
        return F.relu(self.batch(self.conv(x)))


class ResBlock(nn.Module):
    def __init__(self, n_conv_filters: int):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(n_conv_filters, n_conv_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.batch1 = nn.BatchNorm2d(n_conv_filters)
        self.conv2 = nn.Conv2d(n_conv_filters, n_conv_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.batch2 = nn.BatchNorm2d(n_conv_filters)

    def forward(self, x):
        identity = x
        out = F.relu(self.batch1(self.conv1(x)))
        out = self.batch2(self.conv2(out))
        out += identity  # skip connection
        return F.relu(out)


class PolicyHead(nn.Module):
    def __init__(self, n_input_channels: int):
        super(PolicyHead, self).__init__()
        self.conv = nn.Conv2d(n_input_channels, 2, kernel_size=1, stride=1, bias=False)
        self.batch = nn.BatchNorm2d(2)
        self.linear = nn.Linear(84, 7)

    def forward(self, x):
        x = self.conv(x)
        x = self.batch(x)
        x = F.relu(x)
        x = x.view(-1, 84)
        x = self.linear(x)
        return x


class Net(nn.Module):
    def __init__(self, n_res_blocks=19, n_conv_filters=64):
        super(Net, self).__init__()
        self.n_conv_filters = n_conv_filters
        self.n_res_blocks = n_res_blocks
        self.conv_block = ConvBlock(2, n_conv_filters)
        self.res_blocks = nn.ModuleList([ResBlock(n_conv_filters) for _ in range(n_res_blocks)])
        self.policy_head = PolicyHead(n_conv_filters)

    def forward(self, x):
        x = self.conv_block(x)
        for block in self.res_blocks:
            x = block(x)
        return self.policy_head(x),  # note the trailing comma: the output is a 1-tuple

Could you also share more details about the setup such as the GPU being used (e.g., is it compute capability >= 8.0 which could be using TF32 for convolutions)?

The GPU is an NVIDIA GeForce RTX 3090.
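Since the 3090 is an Ampere card (compute capability 8.6), TF32 should indeed be active by default for convolutions. For reference, it can be disabled with the standard PyTorch flags to test whether it is responsible:

# Force full-precision float32 math for cuDNN convolutions and cuBLAS matmuls.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False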

I want to reiterate that I see non-negligible differences even on the CPU. They are smaller (~1e-7), but I think such differences are too large to blame on floating-point arithmetic alone.

I added this to the top of full_demo.py:

torch.set_default_dtype(torch.float64)

With that, the diffs shrink from 1e-3 to 1e-16.

So it appears that this is due to floating-point error, compounded across the network's many layers.
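The magnitudes line up with machine precision for each dtype; for reference:

# Relative rounding error of a single floating-point operation, per dtype.
print(torch.finfo(torch.float32).eps)  # ~1.19e-07
print(torch.finfo(torch.float64).eps)  # ~2.22e-16

A per-operation relative error of ~1e-7 in float32, amplified across roughly 40 conv/batch-norm layers and batch-size-dependent reduction orders, can plausibly grow to ~1e-3; the same amplification starting from ~1e-16 in float64 explains why the diffs shrink to around 1e-16.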

I am quite surprised that a vanilla network of moderate depth can show floating-point discrepancies as large as 1e-3, but I suppose that is how these things work. For whatever reason, the BatchNorm layers seem particularly susceptible: replacing them with nn.Identity yields discrepancies of exactly zero (a sketch of such a swap is below).
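For reference, a recursive swap like the following reproduces the BatchNorm-free variant (a sketch; the helper name is mine, not from the repo):

import torch.nn as nn

def replace_batchnorm_with_identity(module: nn.Module) -> None:
    # Recursively replace every BatchNorm2d with a no-op Identity, in place.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.Identity())
        else:
            replace_batchnorm_with_identity(child)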
