The first comparison looks alright, as the relative error is in the expected range (~1e-7) for float32. You can use integer weights to verify it:
with torch.no_grad():
    # note: randint's upper bound is exclusive, so this samples weights from {-1, 0}
    conv.weight.copy_(torch.randint(-1, 1, (conv.weight.nelement(),)).view_as(conv.weight))
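The idea behind the integer trick: small-integer inputs and weights make every product and partial sum exactly representable in float32, so a correct reimplementation must match bit for bit instead of to ~1e-7. A minimal sketch (the 3→8 channel conv, input shape, and the {-1, 0, 1} weight range are assumptions for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)  # assumed example layer
with torch.no_grad():
    # upper bound 2 is exclusive, so weights come from {-1, 0, 1}
    conv.weight.copy_(torch.randint(-1, 2, (conv.weight.nelement(),)).view_as(conv.weight))

x = torch.randint(-2, 3, (1, 3, 8, 8)).float()  # integer-valued input
out = conv(x)

# recompute in float64: with small integers both results are exact,
# so they agree to the last bit, not just within a tolerance
out_exact = F.conv2d(x.double(), conv.weight.double()).float()
print(torch.equal(out, out_exact))  # True
```

With random float inputs the same comparison would only hold up to ~1e-7 relative error due to rounding and summation order.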
The unfold approach is wrong, and this post gives you an example using a manual unfold approach instead. Fixing it yields no mismatches:
patches = x.unfold(2, K, stride).unfold(3, K, stride)  # (B, C, H_out, W_out, K, K)
patches = patches.contiguous().view(B, C, -1, K, K)    # (B, C, L, K, K) with L = H_out * W_out
patches = patches.permute(0, 2, 1, 3, 4)               # (B, L, C, K, K)
# broadcast each patch against every filter: (B, L, 1, C, K, K) * (1, 1, O, C, K, K)
vect_output = torch.mul(patches.unsqueeze(2), torch_w.unsqueeze(0).unsqueeze(1)).sum(dim=(-1, -2, -3))
vect_output = vect_output.permute(0, 2, 1)             # (B, O, L)
vect_output = vect_output.view(B, -1, out_image.size(2), out_image.size(3))
torch.sum((out_image - torch_conv_out)**2)**0.5 # tensor(0., grad_fn=<PowBackward0>)
torch.sum((vect_output - out_image)**2)**0.5 # tensor(0.)
torch.sum((vect_output - torch_conv_out)**2)**0.5 # tensor(0., grad_fn=<PowBackward0>)
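For completeness, here is the same check as one self-contained sketch; the shapes (B, C, H, W), the number of filters O, the kernel size K, and the stride are assumed example values, since the original snippets define them elsewhere:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 3, 8, 8   # assumed input shape
O, K, stride = 4, 3, 1    # assumed filters, kernel size, stride

x = torch.randn(B, C, H, W)
torch_w = torch.randn(O, C, K, K)
torch_conv_out = F.conv2d(x, torch_w, stride=stride)  # reference result

# extract K x K patches: (B, C, H_out, W_out, K, K) -> (B, L, C, K, K)
patches = x.unfold(2, K, stride).unfold(3, K, stride)
patches = patches.contiguous().view(B, C, -1, K, K).permute(0, 2, 1, 3, 4)

# multiply each patch against every filter via broadcasting and reduce
vect_output = (patches.unsqueeze(2) * torch_w.unsqueeze(0).unsqueeze(1)).sum(dim=(-1, -2, -3))
vect_output = vect_output.permute(0, 2, 1).view(B, -1, torch_conv_out.size(2), torch_conv_out.size(3))

print(torch.allclose(vect_output, torch_conv_out, atol=1e-5))  # True
```

With random float inputs the match is within float32 rounding error; swap in the integer weights from above to make it exact.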