Shuffling the input before the model vs. shuffling the output after the model is not consistent on CUDA

Shuffling the input before feeding it into the model and shuffling the model's output afterwards produce different results.
I think it has something to do with the GPU and batch norm, since the problem only happens in train mode on CUDA, not on the CPU.
Does anyone know why this is happening?

import torch
import torchvision.models as models

model = models.resnet50()  # default train mode: batch norm uses the current batch statistics
model = model.cuda()

idx = torch.randperm(128).cuda()  # random permutation of the entire batch
x = torch.randn(128, 3, 224, 224).cuda()

print("input indexing", model(x[idx]))
print("output indexing", model(x)[idx])

input indexing tensor([[-0.8570, 0.8326, -0.5860, ..., 0.8055, -0.0580, 0.1908],
        [-1.0155, 0.6950, -0.3823, ..., 0.7518, 0.2098, 0.2948],
        [-0.9844, 0.6219, -0.2615, ..., 0.8177, 0.1248, 0.2177],
        ...,
        [-0.9578, 0.5173, -0.1964, ..., 0.6827, 0.0789, 0.2514],
        [-1.0569, 0.5965, -0.3957, ..., 0.6497, 0.0286, 0.3090],
        [-1.0232, 0.6058, -0.2303, ..., 0.6736, 0.1143, 0.4659]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
output indexing tensor([[-0.8633, 0.8343, -0.5875, ..., 0.8065, -0.0620, 0.1971],
        [-1.0178, 0.6932, -0.3813, ..., 0.7535, 0.2155, 0.2961],
        [-0.9846, 0.6146, -0.2619, ..., 0.8111, 0.1284, 0.2190],
        ...,
        [-0.9474, 0.5150, -0.1974, ..., 0.6781, 0.0821, 0.2537],
        [-1.0614, 0.5949, -0.3894, ..., 0.6516, 0.0253, 0.3094],
        [-1.0255, 0.6015, -0.2268, ..., 0.6716, 0.1218, 0.4665]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

The output above shows that the two results differ slightly.

But the mean and variance of each feature channel should not change whether the input is shuffled or not. If idx only selected a subset of the batch, the batch statistics would change, but idx just permutes the entire batch, as the sketch below checks.
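Here is a minimal sketch verifying that intuition, computing the per-channel statistics over dims [0, 2, 3] the way nn.BatchNorm2d does (the allclose tolerance is an arbitrary choice on my part):

import torch

x = torch.randn(128, 3, 224, 224)
idx = torch.randperm(x.size(0))

# per-channel batch statistics, reduced over the batch and spatial dims
mean_orig, var_orig = x.mean(dim=[0, 2, 3]), x.var(dim=[0, 2, 3])
mean_perm, var_perm = x[idx].mean(dim=[0, 2, 3]), x[idx].var(dim=[0, 2, 3])

# identical up to floating point accumulation order
print(torch.allclose(mean_orig, mean_perm, atol=1e-6))  # True
print(torch.allclose(var_orig, var_perm, atol=1e-6))    # True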

nn.BatchNorm2d layers will take the stats from dims [0, 2, 3], and indeed shuffling the samples should not change the applied function. However, you are changing the order of operations (or rather the order of the samples, in this case), which can show small errors due to the limited floating point precision.
Here is a small example using only the sum operation:

x = torch.randn(100, 100)
s1 = x.sum()                             # reduce all elements at once
s2 = x.sum(0).sum(0)                     # reduce per column, then reduce the result
s3 = x[torch.randperm(x.size(0))].sum()  # shuffle the rows, then reduce
print(s1 - s2)
# tensor(-5.7220e-06)
print(s1 - s3)
# tensor(1.1444e-05)

All three approaches calculate the same desired sum, but due to the change in the operation/sample order the outputs show small errors, and you might see the same effect in your model.
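To confirm that the mismatch in your model is just this floating point noise rather than a real bug, you can compare the two outputs with a tolerance instead of expecting bitwise equality (the atol value below is a rough assumption for this output scale):

out1 = model(x[idx])
out2 = model(x)[idx]

print((out1 - out2).abs().max())              # small relative to the output scale
print(torch.allclose(out1, out2, atol=1e-2))  # True when only fp noise differs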
