I was trying to implement a few versions of local image normalization, all involving some variation of blurring the image with a Gaussian and then subtracting the blurred copy from the original. I kept getting odd results, such as occasional output images filled entirely with 0s or -1s. After some investigation, I was able to narrow it down to a minimal example that reproduces the bug.
It turns out that torchvision.transforms.GaussianBlur will occasionally return a tensor identical to its input, seemingly at random. Not only that, but it doesn’t appear to behave deterministically, even with torch.use_deterministic_algorithms(True) set.
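Roughly the kind of check I used while narrowing this down; a stripped-down sketch rather than the exact notebook code:

import torch
import torchvision

torch.use_deterministic_algorithms(True)

blur = torchvision.transforms.GaussianBlur(9)
x = torch.rand(1, 128, 128)

##Apply the same transform to the same tensor twice. With deterministic
##algorithms enabled I would expect the two outputs to match each other,
##and neither of them to exactly equal the unblurred input.
a = blur(x)
b = blur(x)
print(torch.equal(a, b), torch.equal(a, x), torch.equal(b, x))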
I’ve made a Colab notebook to demonstrate the issue. It initializes an image tensor, clones it 64 times, and runs each clone through a simplified local normalization function. This function just blurs the image using torchvision.transforms.GaussianBlur and subtracts the original from it. Even with deterministic algorithms enabled, the output of this operation varies by several orders of magnitude across identical inputs, and it is occasionally all zeros, meaning that GaussianBlur(img) == img in roughly 2 out of 64 runs. This behavior persists on both CPU and GPU.
This very much looks like a bug to me, but perhaps someone who knows PyTorch’s internals better than I do can explain what I’m missing.
The code from the notebook is reproduced below:
import torch
import torchvision
import matplotlib
import matplotlib.pyplot as plt
torch.use_deterministic_algorithms(True)
print(torch.__version__)
def localNorm(img):
    img = img.float()
    #img = img.to("cuda") ##Bug exists on both CPU and GPU
    blur = torchvision.transforms.GaussianBlur(9)
    img_b = blur(img)
    #assert not (img_b == img).all() ##Optional assert to catch the bug here, rather than comparing the sum to 0 later
    img_b = img_b - img ##subtract the original image from the blurred one
    return img_b.cpu().detach()
##three different ways of initializing the image. Two random, and one deterministic. All exhibit the bug
#r = torch.rand(1,2160,2560)
#r = torch.normal(0,1,(1,2160,2560))
r = torch.linspace(0,2160*2560,2160*2560)#*(3.141592653589)
r = torch.cos(r)
r = torch.unflatten(r,0,(1,2160,2560))
##init a tensor 'r' representing an image, and clone it 64 times
dataset = [torch.clone(r) for i in range(64)]
fig, axarr = plt.subplots(8,8)
means = []
##run localNorm on the identical image 64 times, and observe differing results, with the occasional identical (zero sum of differences) result
for i,img in enumerate(dataset):
    norm = localNorm(img)[0]
    axarr[i//8,i%8].imshow(norm)#, cmap='gray')
    mean = torch.sum(norm)
    means.append(mean)
    print(mean.item(), (mean==0).item())
plt.gca().set_axis_off()
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
for ax in fig.axes:
    ax.axis('off')
plt.show()
plt.plot(means)
plt.show()
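For comparison, here is a sketch (not part of the notebook) of what I would expect from a blur with everything pinned down, using the functional API with an explicit, arbitrarily chosen sigma=1.0:

import torch
from torchvision.transforms import functional as TF

x = torch.rand(1, 256, 256)

##With the kernel size and sigma both fixed, the blur should behave like an
##ordinary convolution, so for a non-constant image the result should differ
##from the input rather than occasionally matching it exactly.
x_b = TF.gaussian_blur(x, kernel_size=9, sigma=1.0)
print(torch.equal(x_b, x))          ##expect False
print((x_b - x).abs().max().item()) ##expect a clearly non-zero value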