Differing gradients with same algorithm

Hi all, I need a little help figuring out where my implementation went wrong. I have an algorithm that runs on a 3D tensor of shape [batch, row, col]. I used to use a for loop to iterate over each of the n rows/cols (each matrix is n x n). I figured out how to use Tensor.scatter_add_ to remove that inner for-loop.
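For reference, this is what scatter_add_ along dim=1 does in the scatter version further down: for every column j, index[b, 0, j] picks the single row that receives src[b, 0, j]. A minimal sketch on a toy 3 x 3 tensor (the values are made up purely for illustration):

```python
import torch

# A [batch=1, rows=3, cols=3] tensor of zeros to scatter into
out = torch.zeros(1, 3, 3)

# For column j, index[0, 0, j] names the row that receives src[0, 0, j]
index = torch.tensor([[[2, 0, 1]]])       # shape [1, 1, 3], int64
src = torch.tensor([[[10., 20., 30.]]])   # shape [1, 1, 3]

# With dim=1: out[b, index[b, 0, j], j] += src[b, 0, j]
out.scatter_add_(dim=1, index=index, src=src)
print(out[0].tolist())  # [[0.0, 20.0, 0.0], [0.0, 0.0, 30.0], [10.0, 0.0, 0.0]]
```

Because index has size 1 along dim=1, each column is touched in exactly one row, which is how the scatter version patches only the row that held the column maximum.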

Here’s the input to the function:

import torch

input = torch.Tensor([[[ 7,  8,  7, 20,  7],
                       [15,  2,  2,  4,  8],
                       [ 6, 17,  6,  7,  8],
                       [ 1,  6, 12,  6,  2],
                       [ 8,  1,  7,  2, 10]]])

Here’s the for-loop implementation:

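def iter_simp_min_sum_batch(m_alpha_beta, m_beta_alpha, weights, n_iter):
    n = weights.size(1)  # number of rows/cols of each n x n matrix

    for _ in range(n_iter):

        # Message passing: for each i, take the max over every row except row i
        for i in range(n):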
            alpha_beta_max = torch.max(torch.cat((m_alpha_beta[:, :i, :], m_alpha_beta[:, (i + 1):, :]), dim=1), dim=1)[0]
            beta_alpha_max = torch.max(torch.cat((m_beta_alpha[:, :i, :], m_beta_alpha[:, (i + 1):, :]), dim=1), dim=1)[0]
            if i == 0:
                m_beta_alpha_k = (weights[:, i, :] - alpha_beta_max).unsqueeze(2)
                m_alpha_beta_k = (weights[:, :, i] - beta_alpha_max).unsqueeze(2)
            else:
                m_beta_alpha_k = torch.cat((m_beta_alpha_k, (weights[:, i, :] - alpha_beta_max).unsqueeze(2)), dim=2)
                m_alpha_beta_k = torch.cat((m_alpha_beta_k, (weights[:, :, i] - beta_alpha_max).unsqueeze(2)), dim=2)

        m_alpha_beta = m_alpha_beta_k
        m_beta_alpha = m_beta_alpha_k

    return m_alpha_beta, m_beta_alpha

Here’s my scatter implementation of the above:

def iter_simp_min_sum_batch_scatter(m_alpha_beta, m_beta_alpha, weights, n_iter):

    for _ in range(n_iter):
        # Message passing: the max over all rows except row i equals the column max,
        # except at the row that holds the max, where it is the runner-up (topk, k=2)
        beta_alpha_maxes, beta_alpha_indices = torch.topk(m_beta_alpha, 2, dim=1)
        m_alpha_beta_k = weights.permute(0, 2, 1) - beta_alpha_maxes[:, 0, :].unsqueeze(1)
        m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1), src=beta_alpha_maxes[:, 0, :].unsqueeze(1))
        m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1), src=beta_alpha_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2, 1)
        
        alpha_beta_maxes, alpha_beta_indices = torch.topk(m_alpha_beta, 2, dim=1)
        m_beta_alpha_k = weights - alpha_beta_maxes[:, 0, :].unsqueeze(1)
        m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1), src=alpha_beta_maxes[:, 0, :].unsqueeze(1))
        m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1), src=alpha_beta_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2, 1)
        
        m_alpha_beta = m_alpha_beta_k
        m_beta_alpha = m_beta_alpha_k

    return m_alpha_beta, m_beta_alpha
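The scatter version rests on one identity: within each column, the max over every row except row i equals the column max, except at the row that holds the max, where it equals the second-largest value. The sketch below checks that identity in isolation (no weights, no iterations) on the input posted above. The function names leave_one_out_max_loop and leave_one_out_max_topk are hypothetical, not part of the original code.

```python
import torch

def leave_one_out_max_loop(m):
    # result[b, i, j] = max over rows r != i of m[b, r, j]
    n = m.size(1)
    outs = []
    for i in range(n):
        rest = torch.cat((m[:, :i, :], m[:, i + 1:, :]), dim=1)
        outs.append(rest.max(dim=1).values)  # shape [B, cols]
    return torch.stack(outs, dim=1)          # shape [B, n, cols]

def leave_one_out_max_topk(m):
    vals, idx = torch.topk(m, 2, dim=1)      # top two values per column
    first, second = vals[:, 0:1, :], vals[:, 1:2, :]
    # Start from the column max for every excluded row i ...
    out = first.expand(-1, m.size(1), -1).clone()
    # ... then, at the row holding the max, swap in the runner-up
    out.scatter_add_(dim=1, index=idx[:, 0:1, :], src=second - first)
    return out

m = torch.tensor([[[ 7.,  8.,  7., 20.,  7.],
                   [15.,  2.,  2.,  4.,  8.],
                   [ 6., 17.,  6.,  7.,  8.],
                   [ 1.,  6., 12.,  6.,  2.],
                   [ 8.,  1.,  7.,  2., 10.]]])
print(torch.allclose(leave_one_out_max_loop(m), leave_one_out_max_topk(m)))  # True
```

When the top two values in a column are tied, second - first is 0, so the result is the same whichever tied row topk reports. Only the returned index changes, and with it the entry that receives the gradient.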

Finally, here’s how I test the above implementations:

# Scatter implementation
a = torch.nn.Parameter(input)

scatter_b_a, scatter_b_b = iter_simp_min_sum_batch_scatter(a, a.permute(0, 2, 1), a.clone().detach(), 3)

scatter_b_b.backward(gradient=torch.ones_like(scatter_b_b))

# For-loop implementation
b = torch.nn.Parameter(input)

iter_b_a, iter_b_b = iter_simp_min_sum_batch(b, b.permute(0, 2, 1), b.clone().detach(), 3)

iter_b_b.backward(gradient=torch.ones_like(iter_b_b))

Here are the results:

a.grad

tensor([[[  0.,   0.,  -2.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,  -2.],
         [  0.,   0.,  -4.,   0.,   0.],
         [ -1.,   0.,   0.,   0., -16.]]])


b.grad

tensor([[[  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,   0.],
         [  0.,   0.,   0.,   0.,  -2.],
         [  0.,   0.,  -4.,   0.,   0.],
         [ -1.,   0.,  -2.,   0., -16.]]])

scatter_b_b - iter_b_b

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], grad_fn=<SubBackward0>)

As scatter_b_b - iter_b_b shows, the outputs are identical. The gradients differ, though, and I don't understand where that difference comes from.

Thanks a lot for any help!

Hi,

What are the contents of m_alpha_beta and m_beta_alpha?
One thing that could cause this is if they contain entries that are equal to each other.

Indeed, when there is a tie, topk can return the index of any of the tied elements.

Similarly, in your for-loop version, max will select one element “at random” if several elements are equal to the maximum.

As you saw, this does not change the output at all, because the tied values are the same.
But since these values come from different entries, the gradients (which flow back only to the selected entries) can be different!
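A minimal illustration of this, using a toy 1-D tensor rather than the poster's data: two ways of picking a tied maximum give the same value but send the gradient to different entries, and max along a dim sends it only to the one index it returns.

```python
import torch

x = torch.tensor([7., 2., 7., 1.], requires_grad=True)

# Positions holding the maximum value: indices 0 and 2 (a tie)
tied = (x.detach() == x.detach().max()).nonzero().flatten()

# Two selections that produce the same value ...
first_pick = x[int(tied[0])]
last_pick = x[int(tied[-1])]
assert first_pick.item() == last_pick.item() == 7.0

# ... but send the gradient to different entries
g_first, = torch.autograd.grad(first_pick, x)
g_last, = torch.autograd.grad(last_pick, x)
print(g_first.tolist())  # [1.0, 0.0, 0.0, 0.0]
print(g_last.tolist())   # [0.0, 0.0, 1.0, 0.0]

# max along a dim behaves like one of these picks: the whole gradient
# goes to the index it returns, whichever tied entry that is
vals, idx = torch.max(x, dim=0)
g_max, = torch.autograd.grad(vals, x)
assert g_max[idx].item() == 1.0 and g_max.sum().item() == 1.0
```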

Hope that helps!

Hey! m_alpha_beta contains input, which is just a b x n x n parameter tensor, where b is the batch size. m_beta_alpha contains input.permute(0, 2, 1), which is the transpose of m_alpha_beta.

I tested it with the input above, which is:

input = torch.Tensor([[[ 7,  8,  7, 20,  7],
                       [15,  2,  2,  4,  8],
                       [ 6, 17,  6,  7,  8],
                       [ 1,  6, 12,  6,  2],
                       [ 8,  1,  7,  2, 10]]])

I have also varied the number of iterations, and the gradients always seem to land in slightly different positions.

As you can see in your example, your gradients differ only in the third column, which is exactly the column with duplicate values (two 7s). One version puts the gradient on the top 7 and the other on the bottom 7, so this looks like what is happening, and it is expected behavior!
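One way to confirm this from the numbers already posted is to list where a.grad and b.grad disagree and look up the input values at those positions. This sketch only copies the posted tensors; x is the input from the first post, renamed to avoid shadowing Python's built-in input.

```python
import torch

# Input and gradients copied from the first post
x = torch.tensor([[[ 7.,  8.,  7., 20.,  7.],
                   [15.,  2.,  2.,  4.,  8.],
                   [ 6., 17.,  6.,  7.,  8.],
                   [ 1.,  6., 12.,  6.,  2.],
                   [ 8.,  1.,  7.,  2., 10.]]])
a_grad = torch.tensor([[[ 0., 0., -2., 0.,   0.],
                        [ 0., 0.,  0., 0.,   0.],
                        [ 0., 0.,  0., 0.,  -2.],
                        [ 0., 0., -4., 0.,   0.],
                        [-1., 0.,  0., 0., -16.]]])
b_grad = torch.tensor([[[ 0., 0.,  0., 0.,   0.],
                        [ 0., 0.,  0., 0.,   0.],
                        [ 0., 0.,  0., 0.,  -2.],
                        [ 0., 0., -4., 0.,   0.],
                        [-1., 0., -2., 0., -16.]]])

# [batch, row, col] positions where the two gradients disagree
diff = (a_grad != b_grad).nonzero()
print(diff.tolist())  # [[0, 0, 2], [0, 4, 2]]

# Input values there: the two tied 7s in column index 2 (the third column)
print(x[diff[:, 0], diff[:, 1], diff[:, 2]].tolist())  # [7.0, 7.0]

# The total gradient is the same; it just lands on a different tied entry
print(a_grad.sum().item(), b_grad.sum().item())  # -25.0 -25.0
```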

Oh, right! I just tested with a new tensor that has unique values in every row and column, and the gradients are the same! Thanks a lot. It appears that max and topk were returning different indices for the tied values, just as you suspected. Here are some results for reference:

tens = torch.Tensor([[[ 2.,  8., 10., 20.,  7.],
                      [15.,  3.,  2.,  4.,  9.],
                      [10., 17.,  6.,  7.,  8.],
                      [ 1.,  7., 12.,  6.,  2.],
                      [ 8.,  1.,  7.,  2., 10.]]])

a.grad
tensor([[[ 0., -1., -1., -4.,  0.],
         [-4.,  0.,  0.,  0., -1.],
         [-1., -4.,  0., -1.,  0.],
         [ 0.,  0., -4.,  0.,  0.],
         [ 0.,  0.,  0.,  0., -4.]]])

b.grad
tensor([[[ 0., -1., -1., -4.,  0.],
         [-4.,  0.,  0.,  0., -1.],
         [-1., -4.,  0., -1.,  0.],
         [ 0.,  0., -4.,  0.,  0.],
         [ 0.,  0.,  0.,  0., -4.]]])
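Since the mismatch only shows up with ties, a quick check for repeated values along a dimension can indicate in advance whether the two versions may route gradients differently. has_ties is a hypothetical helper, and it only inspects the starting tensor; messages produced in later iterations could in principle still tie.

```python
import torch

def has_ties(t, dim):
    # Sort along `dim`; equal neighbours after sorting mean a repeated value
    s, _ = torch.sort(t, dim=dim)
    n = s.size(dim)
    return bool((s.narrow(dim, 1, n - 1) == s.narrow(dim, 0, n - 1)).any())

original = torch.tensor([[[ 7.,  8.,  7., 20.,  7.],
                          [15.,  2.,  2.,  4.,  8.],
                          [ 6., 17.,  6.,  7.,  8.],
                          [ 1.,  6., 12.,  6.,  2.],
                          [ 8.,  1.,  7.,  2., 10.]]])
tens = torch.tensor([[[ 2.,  8., 10., 20.,  7.],
                      [15.,  3.,  2.,  4.,  9.],
                      [10., 17.,  6.,  7.,  8.],
                      [ 1.,  7., 12.,  6.,  2.],
                      [ 8.,  1.,  7.,  2., 10.]]])

# dim=1 checks within each column, dim=2 within each row
print(has_ties(original, dim=1), has_ties(original, dim=2))  # True True
print(has_ties(tens, dim=1), has_ties(tens, dim=2))          # False False
```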
