Different results for forward pass of two equal tensors

Hi,

Forwarding a tensor img through a PyTorch model produces a different result than forwarding img + torch.zeros_like(img).

Here is a minimal example: https://github.com/dozed/pytorch-issue-1/blob/main/test_issue.py

The relevant code is:

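import torch
import torchvision
from torch import Tensor, nn
from torchvision.transforms import functional as FT

# `model` and `device` are set up in the linked example
def do_forward_pass(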
        model: nn.Module,
        add_zeros: bool,
) -> Tensor:
    # prepare input
    img = torchvision.io.image.read_image('sky1024px.jpg')
    img = FT.convert_image_dtype(img, torch.float32)
    img = img.unsqueeze(dim=0)
    img = img.to(device)

    # maybe add zero tensor
    if add_zeros:
        zeros = torch.zeros_like(img)
        img_updated = img + zeros
    else:
        img_updated = img

    assert torch.allclose(img, img_updated)
    assert torch.equal(img, img_updated)

    # forward pass
    result = model(img_updated)

    return result


result1 = do_forward_pass(model, add_zeros=False)
result2 = do_forward_pass(model, add_zeros=True)

assert torch.allclose(result1, result2)
assert torch.equal(result1, result2)

result1 and result2 should be equal, but they are different.

Any ideas what causes this divergence?

Regards,
Stefan

Are you using any random layers, such as dropout? Also, how large is the difference?

The model does not use random layers.

The difference is as follows:

torch.linalg.norm(result1 - result2)
# tensor(6.4035e-06)

Are you seeing the same difference in two different runs without adding the zeros to the tensor?

Without adding zeros the result tensors of two runs are equal, i.e. the following assertions hold:

result1 = do_forward_pass(model, add_zeros=False)
result2 = do_forward_pass(model, add_zeros=False)

assert torch.allclose(result1, result2)
assert torch.equal(result1, result2)

I could further reduce the example case to a model with two Conv2d layers:

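import torch
import torchvision
from torch import nn
from torchvision.transforms import functional as FT

# prepare input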
img = torchvision.io.image.read_image('sky1024px.jpg')
img = FT.convert_image_dtype(img, torch.float32)
img = img.unsqueeze(dim=0)

# prepare input + zero
zeros = torch.zeros_like(img)
img_updated = img + zeros

# input tensors are identical
assert torch.allclose(img, img_updated)
assert torch.equal(img, img_updated)

# prepare model
conv1 = nn.Conv2d(in_channels=3, out_channels=129, kernel_size=3, padding=1)
conv2 = nn.Conv2d(in_channels=129, out_channels=4, kernel_size=1)

# forward 1
x = conv1(img)
result1 = conv2(x)

# forward 2
y = conv1(img_updated)
result2 = conv2(y)

# ISSUE: the results are not equal but should be, since only zeros are added
print(torch.linalg.norm(result1 - result2))
assert torch.allclose(result1, result2)
assert torch.equal(result1, result2)

Example: https://github.com/dozed/pytorch-issue-1/blob/main/test_issue.py#L39-L42

I cannot reproduce the issue and get a zero error using a fixed random input:

import torch
import torchvision
from torch import nn
from torchvision.transforms import functional as FT


def test_add_zero_different_result():
    # prepare input
    torch.manual_seed(2809)
    img = torch.randn(1, 3, 224, 224)

    # prepare input + zero
    zeros = torch.zeros_like(img)
    img_updated = img + zeros

    # input tensors are identical
    assert torch.allclose(img, img_updated)
    assert torch.equal(img, img_updated)

    # prepare model
    conv1 = nn.Conv2d(in_channels=3, out_channels=129, kernel_size=3, padding=1)
    conv2 = nn.Conv2d(in_channels=129, out_channels=4, kernel_size=1)
    conv1.requires_grad_(False)
    conv2.requires_grad_(False)

    # forward 1
    x = conv1(img)
    result1 = conv2(x)

    # forward 2
    y = conv1(img_updated)
    result2 = conv2(y)

    # tensors after conv1 are equal
    assert torch.allclose(x, y)
    assert torch.equal(x, y)

    # ISSUE: the results are not equal but should be, since only zeros are added
    print(torch.linalg.norm(result1 - result2))
    assert torch.allclose(result1, result2)
    assert torch.equal(result1, result2)

for _ in range(100):
    test_add_zero_different_result()

# tensor(0.)
# tensor(0.)
# tensor(0.)
# tensor(0.)
...

Please try the exact code from the example repository. It works for random input, but not for the image input.

Also, when you change the number of channels from 129 to 128, the test passes.

Another finding: without the img.unsqueeze(dim=0) call, the test passes.

After some further debugging I could pinpoint why this happens.

torchvision.io.image.read_image produces a CHW tensor that is a permutation of an underlying HWC tensor. That means that for a CHW tensor of size [3, 1, 2] the strides are [1, 6, 3].

def print_tensor_info(x: torch.Tensor) -> None:
    print('size:', x.size())
    print('stride:', x.stride())

orig = torch.tensor([
    [
        [1, 2, 3],
        [4, 5, 6],
    ],
])

print_tensor_info(orig)
# size: torch.Size([1, 2, 3])
# stride: (6, 3, 1)

permuted = orig.permute(2, 0, 1)

print_tensor_info(permuted)
# size: torch.Size([3, 1, 2])
# stride: (1, 6, 3)

Calling unsqueeze on this permuted CHW tensor to get an NCHW tensor sets an invalid batch stride of 3:

permuted_unsqueezed = permuted.unsqueeze(dim=0)

print_tensor_info(permuted_unsqueezed)
# size: torch.Size([1, 3, 1, 2])
# stride: (3, 1, 6, 3)

The correct value for a channels-last layout would be 6 (C * H * W = 3 * 1 * 2). The value 3 comes from a simple heuristic that picks sizes[dim] * strides[dim] = 3 * 1, see https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TensorShape.cpp#L3191-L3199:

InferUnsqueezeGeometryResult
inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
  InferUnsqueezeGeometryResult result(tensor.sizes(), tensor.strides());
  int64_t new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim];
  result.sizes.insert(result.sizes.begin() + dim, 1);
  result.strides.insert(result.strides.begin() + dim, new_stride);

  return result;
}
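
To make the heuristic easier to follow, here is a minimal pure-Python sketch of the same arithmetic, applied to the sizes and strides from the example above. The function names infer_unsqueeze_geometry and channels_last_strides are my own (hypothetical) helpers and not part of PyTorch; the second one only encodes the usual NHWC stride formula for comparison.

```python
def infer_unsqueeze_geometry(sizes, strides, dim):
    """Pure-Python paraphrase of the C++ heuristic quoted above (hypothetical helper)."""
    sizes, strides = list(sizes), list(strides)
    # The new stride is derived only from the dimension that currently sits at `dim`.
    new_stride = 1 if dim >= len(sizes) else sizes[dim] * strides[dim]
    sizes.insert(dim, 1)
    strides.insert(dim, new_stride)
    return sizes, strides


def channels_last_strides(sizes):
    """Strides of a dense NCHW tensor stored in NHWC order (hypothetical helper)."""
    n, c, h, w = sizes
    return [h * w * c, 1, w * c, c]


# CHW tensor of size [3, 1, 2] with strides [1, 6, 3], unsqueezed at dim 0
new_sizes, new_strides = infer_unsqueeze_geometry([3, 1, 2], [1, 6, 3], dim=0)
print(new_sizes, new_strides)            # [1, 3, 1, 2] [3, 1, 6, 3]
print(channels_last_strides(new_sizes))  # [6, 1, 6, 3]
```

The heuristic multiplies the size and stride of the channel dimension (3 * 1), while a channels-last batch stride needs the product of all three inner sizes.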

This permuted_unsqueezed tensor with sizes [1, 3, 1, 2] and strides [3, 1, 6, 3] is treated as having a Contiguous memory format (and not a ChannelsLast memory format): https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/TensorBase.h#L270-L289.

Adding zeros_like yields the following tensor:

zeros_like = torch.zeros_like(permuted_unsqueezed)

permuted_unsqueezed_added_zeros = permuted_unsqueezed.add(zeros_like)

print_tensor_info(permuted_unsqueezed_added_zeros)
# size: torch.Size([1, 3, 1, 2])
# stride: (6, 1, 6, 3)

Note that the strides [6, 1, 6, 3] are correct this time, i.e. they match a channels-last layout. Therefore the tensor is handled as having ChannelsLast memory format.
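
To see why the batch stride alone flips the memory format, here is a rough pure-Python paraphrase of the stride check that PyTorch applies to 4-D tensors, as I read it in c10/core/MemoryFormat.h. Treat it as a sketch under that assumption: looks_channels_last is a hypothetical name, and the exact conditions may differ between PyTorch versions.

```python
def looks_channels_last(sizes, strides):
    """Hypothetical paraphrase of the 4-D channels-last stride heuristic."""
    if len(sizes) != 4 or strides[1] == 0:
        return False
    min_stride = 0
    # Visit the dims from fastest- to slowest-varying in NHWC order: C, W, H, N.
    for d in (1, 3, 2, 0):
        if sizes[d] == 0:
            return False
        # Each stride must be at least as large as the extent covered so far.
        if strides[d] < min_stride:
            return False
        # Ambiguous case: fall back to the default (contiguous) format.
        if d == 0 and min_stride == strides[1]:
            return False
        min_stride = strides[d]
        if sizes[d] > 1:
            min_stride *= sizes[d]
    return True


print(looks_channels_last([1, 3, 1, 2], [3, 1, 6, 3]))  # False: batch stride 3 < 6
print(looks_channels_last([1, 3, 1, 2], [6, 1, 6, 3]))  # True
```

With a batch stride of 3 the last check fails, because the C, W and H dimensions already span 6 elements.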

When both tensors are put through the same convolution layer, they produce slightly different results, since there are different cases for Contiguous and ChannelsLast tensors in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/ConvolutionMM2d.cpp#L247-L283:

if (is_channels_last) {
  // ...

  at::native::cpublas::gemm(...);
} else {
  // ...

  at::native::cpublas::gemm(...);
}
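
The reason two mathematically equivalent code paths can disagree at all is that floating-point addition is not associative, so accumulating the same products in a different order can change the last bits. A minimal illustration in plain Python (double precision, unrelated to the actual gemm kernels):

```python
import math

# Same three numbers, summed in two different orders
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left == right)               # False
print(math.isclose(left, right))   # True: they differ only in the last bit
```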

The difference is on the order of 10^-8, so in practice it should not be a problem.
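
For tests, that means comparing with a tolerance rather than bitwise. A small sketch with made-up values (the tensors a and b below are illustrative and not taken from the model output):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
# Differs from `a` by one float32 rounding step in the first element
b = torch.tensor([1.0000001, 2.0, 3.0])

print(torch.equal(a, b))          # False: bitwise comparison
print(torch.allclose(a, b))       # True: default rtol=1e-5, atol=1e-8
print((a - b).abs().max())        # largest element-wise deviation

# Raises AssertionError with a readable report if the tolerance is exceeded
torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-6)
```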

Also, this seems to be a known issue: https://github.com/pytorch/pytorch/issues/68430#issuecomment-970895522
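
If bitwise-identical results matter, one possible workaround (my assumption, not verified against the original image or repository) is to normalize the layout before unsqueezing, so that both inputs have the same strides and should take the same convolution path. The sketch below uses the small tensor from above instead of the image:

```python
import torch

# HWC tensor with H=1, W=2, C=3, permuted to CHW like the read_image output
hwc = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3)
chw = hwc.permute(2, 0, 1)

# Original path: permuted tensor unsqueezed directly
nchw = chw.unsqueeze(dim=0)
print(nchw.stride())             # (3, 1, 6, 3)

# Workaround: copy into the default contiguous layout first
fixed = chw.contiguous().unsqueeze(dim=0)
fixed_plus_zero = fixed + torch.zeros_like(fixed)
print(fixed.stride())            # (6, 2, 2, 1)
print(fixed_plus_zero.stride())  # (6, 2, 2, 1)
```

Note that the copy has to use the default format here: contiguous(memory_format=torch.channels_last) may return the tensor unchanged if it already passes the channels-last contiguity check.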