I’ve tried and compared three different methods of computing a convolution with a custom kernel in PyTorch. Their results differ, and I don’t understand why.
Setup code:
import torch
import torch.nn.functional as F
inp = torch.arange(3*500*700).reshape(1,3,500,700).to(dtype=torch.float32)
wgt = torch.ones((1,3,3,3)).to(dtype=torch.float32)
stride = 1
padding = 0
h = inp.shape[2] - wgt.shape[2] + 1
w = inp.shape[3] - wgt.shape[3] + 1
Method 1
out1 = torch.zeros((1,h,w)).to(dtype=torch.float32)
for o in range(1):
    for i in range(3):
        for j in range(h):
            for k in range(w):
                out1[o,j,k] = out1[o,j,k] + (inp[0, i, j*stride:j*stride+3, k*stride:k*stride+3] * wgt[0,i]).sum()
out1 = out1.to(dtype=torch.int)
Method 2
inp_unf = F.unfold(inp, (3,3))
out_unf = inp_unf.transpose(1,2).matmul(wgt.view(1,-1).t()).transpose(1,2)
out2 = F.fold(out_unf, (h,w), (1,1))
out2 = out2.to(dtype=torch.int)
Method 3
out3 = F.conv2d(inp, wgt, bias=None, stride=1, padding=0)
out3 = out3.to(dtype=torch.int)
And here are the results comparison:
>>> h*w
347604
>>> (out1==out2).sum().item()
327338
>>> (out2 == out3).sum().item()
344026
>>> (out1 == out3).sum().item()
330797
>>> out1.shape
(1, 498, 698)
>>> out2.shape
(1, 1, 498, 698)
>>> out3.shape
(1, 1, 498, 698)
Their data types are all int, so floating point shouldn’t affect the result. When I use a square input such as h=500 and w=500, all three results match. But they don’t match for non-square inputs, such as the one above with h=500 and w=700. Any insight?
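One way to test the assumption that floating point is not involved: the sums are accumulated in float32 before the cast to int, and float32 represents integers exactly only up to 2**24. The sketch below is plain Python, not PyTorch. It assumes the arange input and all-ones 3x3 kernel from the setup above, and the helper names (`to_float32`, `max_window_sum`) are hypothetical. It computes the largest output value for both input sizes and compares it with that limit.

```python
import struct

# float32 has a 24-bit significand, so integers are exact only up to 2**24.
FLOAT32_EXACT_INT_LIMIT = 2 ** 24


def to_float32(x):
    """Round a number to the nearest float32 value (hypothetical helper)."""
    return struct.unpack("f", struct.pack("f", float(x)))[0]


def max_window_sum(channels, height, width, k=3):
    """Largest output of an all-ones k x k kernel applied to
    arange(channels*height*width) reshaped to (channels, height, width).
    The maximum is reached at the bottom-right window."""
    total = 0
    for c in range(channels):
        for r in range(height - k, height):
            for col in range(width - k, width):
                total += c * height * width + r * width + col
    return total


for h_in, w_in in [(500, 500), (500, 700)]:
    peak = max_window_sum(3, h_in, w_in)
    print(h_in, w_in, peak, peak > FLOAT32_EXACT_INT_LIMIT)

# An integer just above the limit does not survive a float32 round trip.
print(to_float32(2 ** 24 + 1) == 2 ** 24 + 1)
```

If the peak exceeds the limit for 500x700 but not for 500x500, the three methods may simply be rounding differently because they add the same terms in a different order. Repeating the comparison with `dtype=torch.float64` would be one way to confirm or rule this out.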