What's the difference in gradient backprop between slicing and nn.Unfold?

My code used to contain loop slicing, which was slow. Recently I changed it to nn.Unfold, which computes the same result. However, I found that my model converged to a worse state and its evaluation metrics deteriorated. After inspection, I found that although the two produce the same output, the gradients produced by backpropagation are different. What should I do? My environment is python==3.6.10, torch==1.4.0. Here is a simple code snippet reproducing the difference:

import torch
import torch.nn as nn
import torch.nn.functional as F

class simple_net(nn.Module):
    def __init__(self):
        super(simple_net, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv2 = nn.Conv2d(8, 3, 3, 1, 1)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

model = simple_net().cuda()
unfold = nn.Unfold(3, padding=0)
x = torch.randn([6,3,8,8], dtype=torch.float32).cuda()
x = torch.autograd.Variable(x)
result = model(x)
b, c, h, w = result.shape
result_pad = F.pad(result, [1,1,1,1])
result_pad.retain_grad()
result_pad_1 = result_pad.clone()
result_pad_1.retain_grad()
result_unfold = unfold(result_pad)
slice_list = []
for i in range(h):
    for j in range(w):
        slicetmp = result_pad_1[..., i:i+3, j:j+3].reshape(b, -1)
        slice_list.append(slicetmp)
result_slice = torch.stack(slice_list, dim=2)
print("forward result:, ", str(torch.sum(torch.abs(result_unfold - result_slice))))
loss = result_unfold.abs().sum() + result_slice.abs().sum()
loss.backward()
print("gradient result:, ", str(torch.sum(torch.abs(result_pad.grad - result_pad_1.grad))))

outputs:

forward result:,  tensor(0., device='cuda:0', grad_fn=<SumBackward0>)  
gradient result:,  tensor(8712., device='cuda:0') 

Many thanks in advance!

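The gradient results may differ because of how you are testing: result_pad_1 is a clone of result_pad, so the gradient that reaches result_pad_1 through the slicing path also flows back into result_pad. After backward(), result_pad.grad holds the unfold-path gradient plus the slicing-path gradient, while result_pad_1.grad holds only the slicing-path gradient. Do you observe that one gradient is twice the value of the other?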
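A CPU-only sketch of the same test, assuming a random tensor `result` stands in for the model output (the gradient at `result_pad` does not depend on what produced it). It checks that `result_pad.grad` is exactly twice `result_pad_1.grad`, and that the reported difference equals the slicing-path gradient alone. With `abs()` in the loss, each covered element receives ±1 per window it appears in, the zero padding receives 0, and each of the 8 interior rows (and columns) is covered by 2, 3, 3, 3, 3, 3, 3, 2 windows, which gives 6 × 3 × 22 × 22 = 8712.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
b, c, h, w = 6, 3, 8, 8
# Assumption: a random leaf tensor replaces the conv model's output.
result = torch.randn(b, c, h, w, requires_grad=True)

result_pad = F.pad(result, [1, 1, 1, 1])
result_pad.retain_grad()
result_pad_1 = result_pad.clone()  # result_pad_1 is a function of result_pad
result_pad_1.retain_grad()

result_unfold = nn.Unfold(3, padding=0)(result_pad)
result_slice = torch.stack(
    [result_pad_1[..., i:i + 3, j:j + 3].reshape(b, -1)
     for i in range(h) for j in range(w)],
    dim=2,
)

loss = result_unfold.abs().sum() + result_slice.abs().sum()
loss.backward()

# result_pad.grad = unfold-path grad + slice-path grad (through the clone);
# result_pad_1.grad = slice-path grad only.
doubled = torch.allclose(result_pad.grad, 2 * result_pad_1.grad)
diff = (result_pad.grad - result_pad_1.grad).abs().sum().item()

# Windows covering each interior position along one axis: 2, 3, ..., 3, 2.
per_axis = 2 + 3 * (h - 2) + 2          # 22 for h = w = 8
expected = b * c * per_axis * per_axis  # 6 * 3 * 22 * 22
print(doubled, diff, expected)
```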

(btw, is there a reason why you haven't upgraded your PyTorch version?)

If you'd like to know whether the two outputs have the same value, computing the sum of absolute differences is one heuristic, but you can also use torch.allclose(a, b).
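A minimal sketch of the forward comparison, assuming a random already-padded 10×10 input rather than the model output. `torch.allclose` checks every element against a relative plus absolute tolerance (defaults `rtol=1e-05`, `atol=1e-08`), which also tolerates tiny floating-point rounding that would make an exact-zero sum fail for operations that are not pure copies.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 10, 10)
unfolded = nn.Unfold(kernel_size=3)(x)  # shape (2, 27, 64)
sliced = torch.stack(
    [x[..., i:i + 3, j:j + 3].reshape(2, -1) for i in range(8) for j in range(8)],
    dim=2,
)

# Single-number heuristic: sum of absolute differences.
print((unfolded - sliced).abs().sum())
# Element-wise comparison within tolerance.
print(torch.allclose(unfolded, sliced))
```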

If you'd like to check whether two functions compute the same gradient, use grad_inp = torch.autograd.grad(output, inp) to compute the gradient of each output with respect to the input separately, and then compare the results with torch.allclose.
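A sketch of that approach, assuming a random tensor `inp` in place of `result_pad` and the same `abs().sum()` loss as in the question; `unfold_patches` and `slice_patches` are hypothetical helper names. Each path builds its own graph from `inp`, so nothing accumulates across them. `torch.autograd.grad` needs a scalar output here (otherwise pass `grad_outputs`) and returns a tuple with one gradient per input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
inp = torch.randn(6, 3, 10, 10, requires_grad=True)  # stand-in for result_pad

def unfold_patches(t, k=3):
    # Hypothetical helper: 3x3 patches via nn.Unfold.
    return nn.Unfold(kernel_size=k, padding=0)(t)

def slice_patches(t, k=3):
    # Hypothetical helper: the same patches via explicit slicing.
    n = t.shape[0]
    out_h, out_w = t.shape[-2] - k + 1, t.shape[-1] - k + 1
    return torch.stack(
        [t[..., i:i + k, j:j + k].reshape(n, -1)
         for i in range(out_h) for j in range(out_w)],
        dim=2,
    )

out_unfold = unfold_patches(inp)
out_slice = slice_patches(inp)

# One separate backward pass per path; each returns a 1-tuple.
(grad_unfold,) = torch.autograd.grad(out_unfold.abs().sum(), inp)
(grad_slice,) = torch.autograd.grad(out_slice.abs().sum(), inp)

print(torch.allclose(out_unfold, out_slice))    # forward outputs match
print(torch.allclose(grad_unfold, grad_slice))  # gradients match
```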