Why does rotating both input and kernel not give rotated output in conv2d?

Hi,

I have the following minimal code example:

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 100, 100) - 0.5
w = torch.rand(1, 1, 5, 5) - 0.5
y1 = F.conv2d(x, w, stride=1, padding=0)

# Rotate both the input and the kernel by 90°, then convolve.
x90 = torch.rot90(x, 1, (2, 3))
w90 = torch.rot90(w, 1, (2, 3))
y2 = F.conv2d(x90, w90, stride=1, padding=0)

# Compare against the rotated output of the original convolution.
y1_rot = torch.rot90(y1, 1, (2, 3))
print(torch.allclose(y2, y1_rot))  # prints False

My expectation:

  • If I rotate the input by 90° and also rotate the kernel by 90°, then apply convolution, the result should be the rotated version of the original convolution output.

  • In other words, I expected y2 == rot90(y1) up to floating-point tolerance (see the small exact-arithmetic check right after this list).
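To make sure the identity itself is what I think it is, here is the same check with integer-valued tensors in float64 (a small sketch, stride 1 and no padding as above), where no rounding can occur; this one prints True:

import torch
import torch.nn.functional as F

xi = torch.randint(-5, 5, (1, 1, 8, 8)).double()
wi = torch.randint(-5, 5, (1, 1, 3, 3)).double()
a = torch.rot90(F.conv2d(xi, wi), 1, (2, 3))                          # rotate the output
b = F.conv2d(torch.rot90(xi, 1, (2, 3)), torch.rot90(wi, 1, (2, 3)))  # rotate input and kernel
print(torch.equal(a, b))  # True: exact equality when there is no rounding error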

But in practice the floating-point code above prints False.

Is this behavior expected in torch.nn.functional.conv2d?

Thanks in advance!

The default tolerances might just be too strict here: the allclose check seems to pass with atol=1e-7.
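For example, continuing from your snippet:

print((y2 - y1_rot).abs().max())              # tiny compared to the values in y1
print(torch.allclose(y2, y1_rot, atol=1e-7))  # True with the looser absolute tolerance

The mismatch most likely comes from floating-point rounding: the products are accumulated in a different order for the rotated input and kernel, so the results can differ by a few float32 ulps.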


Thanks! You’re right.
In this case it was the threshold. Previously I had seen much larger differences because I was using stride (2, 2) on an image with even dimensions.
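For reference, this is roughly what the stride (2, 2) case looks like (a sketch with a hypothetical 5x5 kernel on the even 100x100 input); the mismatch there is much larger than rounding error, since, as far as I can tell, the stride-2 sampling grid of the rotated image does not line up with the rotated stride-2 grid of the original:

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 100, 100) - 0.5
w = torch.rand(1, 1, 5, 5) - 0.5
y1 = F.conv2d(x, w, stride=2, padding=0)
y2 = F.conv2d(torch.rot90(x, 1, (2, 3)), torch.rot90(w, 1, (2, 3)), stride=2, padding=0)
print((torch.rot90(y1, 1, (2, 3)) - y2).abs().max().item())  # far larger than float rounding error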

Now the differences are very small; they are negligible in a single layer, but they become more significant when stacking multiple layers.
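Here is a rough sketch of what I mean by the error growing with depth (a hypothetical stack of random 3x3 convolutions, stride 1, no padding):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
y = torch.rand(1, 1, 100, 100) - 0.5
yr = torch.rot90(y, 1, (2, 3))

for depth in range(1, 6):
    w = torch.rand(1, 1, 3, 3) - 0.5
    y = F.conv2d(y, w)                            # original orientation
    yr = F.conv2d(yr, torch.rot90(w, 1, (2, 3)))  # rotated orientation
    diff = (torch.rot90(y, 1, (2, 3)) - yr).abs().max().item()
    print(depth, diff)  # the maximum discrepancy tends to grow with depth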

Something that caught my attention is that the cosine similarity is always greater between 180° rotations than between 90° rotations. Any idea how to analyze this?

Thanks again!

Here’s an example where you can see the behavior I mentioned before.

import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import seaborn as sns

s = 1  # stride
p = 0  # padding
err = np.zeros((4, 4))  # accumulated pairwise absolute differences

for _ in range(1000):
    x0 = torch.rand(1, 1, 100, 100) - 0.5
    x90 = torch.rot90(x0, 1, (2, 3))
    x180 = torch.rot90(x0, 2, (2, 3))
    x270 = torch.rot90(x0, 3, (2, 3))

    w0 = torch.rand(1, 1, 5, 5) - 0.5
    w90 = torch.rot90(w0, 1, (2, 3))
    w180 = torch.rot90(w0, 2, (2, 3))
    w270 = torch.rot90(w0, 3, (2, 3))

    y0 = F.conv2d(x0, w0, stride=s, padding=p)
    y90 = F.conv2d(x90, w90, stride=s, padding=p)
    y180 = F.conv2d(x180, w180, stride=s, padding=p)
    y270 = F.conv2d(x270, w270, stride=s, padding=p)

    # Rotate each output back so that all four live in the 0° frame.
    y90_r = torch.rot90(y90, -1, (2, 3))
    y180_r = torch.rot90(y180, -2, (2, 3))
    y270_r = torch.rot90(y270, -3, (2, 3))

    outputs = [y0, y90_r, y180_r, y270_r]
    for i in range(4):
        for j in range(4):
            err[i, j] += torch.abs(outputs[i] - outputs[j]).sum().item()

print(err)
plt.figure(figsize=(3, 3))
sns.heatmap(err, cmap='gray', cbar=False)
plt.show()

You’ll always see the same pattern in the error matrix: the largest entries are at (0, 2) and (1, 3), i.e. between the pairs of outputs whose rotations differ by 180°.
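The cosine similarities I mentioned can be computed along the same lines (a sketch, not my exact code; each realigned output is flattened and the pairs are compared with F.cosine_similarity):

import torch
import torch.nn.functional as F

x0 = torch.rand(1, 1, 100, 100) - 0.5
w0 = torch.rand(1, 1, 5, 5) - 0.5

outs = []
for k in range(4):
    xk = torch.rot90(x0, k, (2, 3))
    wk = torch.rot90(w0, k, (2, 3))
    yk = F.conv2d(xk, wk, stride=1, padding=0)
    outs.append(torch.rot90(yk, -k, (2, 3)).flatten())  # realign to the 0° frame

for i in range(4):
    for j in range(4):
        sim = F.cosine_similarity(outs[i], outs[j], dim=0).item()
        print(i, j, sim)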