Mathematically equivalent computations with index_select give different gradients

I am trying to calculate the pairwise distances of a given list of normalized vectors. Here's how I initialize the inputs:

import torch

torch.manual_seed(2)
inputs = torch.nn.functional.normalize(torch.rand((10,2)))
# inputs1 will be passed through method1
inputs1 = inputs.clone().detach().requires_grad_(True)
# inputs2 will be passed through method2
inputs2 = inputs.clone().detach().requires_grad_(True)

labels = torch.tensor([1,2,3,4,5,6,7,8,9,0])

Method 1

n = inputs1.size(0)
d = inputs1.size(1)
x1 = inputs1.unsqueeze(1).expand(n, n, d)
y1 = inputs1.unsqueeze(0).expand(n, n, d)
z1 = torch.pow(x1 - y1, 2).sum(2)
z1 = z1 + torch.diag(torch.ones(inputs1.size(0)) * 1e-5)
z1 = torch.sqrt(z1)
z1 = torch.masked_select(z1[0], labels!=1)
method1 = torch.sum(z1)
method1.backward()
print(inputs1.grad)

Output of method 1

tensor([[ 2.8884, -1.6397],
        [-0.5626, 0.8267],
        [-0.5930, 0.8052],
        [ 0.3746, -0.9272],
        [ 0.3183, -0.9480],
        [-0.6094, 0.7928],
        [-0.7579, 0.6523],
        [-0.6741, 0.7386],
        [ 0.4057, -0.9140],
        [-0.7900, 0.6131]])

Method 2

x2 = inputs2
y2 = inputs2.t()
cos_sim_matrix = torch.mm(x2, y2)
z2 = (torch.sum(x2 ** 2, dim=1) + torch.sum(y2 ** 2, dim=0)) - 2 * cos_sim_matrix
z2 = z2 + torch.diag(torch.ones(inputs2.size(0)) * 1e-5)
z2 = torch.sqrt(z2)
z2 = torch.masked_select(z2[0], labels!=1)
method2 = torch.sum(z2)
method2.backward()
print(inputs2.grad)

Output of method 2

tensor([[-29.0625, -21.4441],
        [ 8.8515, 7.8375],
        [ 4.1426, 4.9133],
        [ 3.2485, -0.3052],
        [ 2.4937, -0.7449],
        [ 3.0176, 4.2116],
        [ -0.1018, 2.1812],
        [ 0.9631, 2.9099],
        [ 3.9192, 0.0983],
        [ -0.3604, 1.9822]])

The gradients of the two methods are the same if masked_select is commented out (so that the whole distance matrix is summed); they differ only when I add masked_select or index_select. Did I make a mistake in the calculation?

Hello,

Did you find a solution for the difference?

Not yet. I still don't know why adding the select causes this behavior. I tested the select on its own, and it seems fine…

Problem solved. The mistake is in this line of Method 2:

z2 = (torch.sum(x2 ** 2, dim=1) + torch.sum(y2 ** 2, dim=0)) - 2 * cos_sim_matrix

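Both torch.sum calls return 1-D tensors of shape (n,), so adding them is an elementwise sum that gives 2 * ||x_k||^2 for each k, still of shape (n,). Broadcasting only happens at the subtraction with the (n, n) matrix 2 * cos_sim_matrix, where that vector is treated as a row, so entry [i, j] becomes 2 * ||x_j||^2 - 2 * x_i·x_j instead of ||x_i||^2 + ||x_j||^2 - 2 * x_i·x_j. Because the inputs are normalized, both expressions evaluate to 2 - 2 * x_i·x_j and the forward values match, but the autograd graphs differ, so the gradients differ too. The correct broadcast keeps the reduced dimension, giving shapes (n, 1) and (1, n) that expand to the intended (n, n) matrix: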

z2 = (torch.sum(x2 ** 2, dim=1, keepdim=True) + torch.sum(y2 ** 2, dim=0, keepdim=True)) - 2 * cos_sim_matrix
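To see why only the selected version shows a difference, write $x_i$ for the $i$-th row of inputs2 and $z_{ij}$ for the squared-distance entry before the square root, and ignore the $10^{-5}$ diagonal term (the mask excludes the diagonal entry of row 0, and on the full matrix it does not break symmetry). For the row-0 loss $L = \sum_{j \ne 0} \sqrt{z_{0j}}$, the two expressions give:

$$
\begin{aligned}
\text{keepdim=True:}\quad & z_{0j} = \lVert x_0\rVert^2 + \lVert x_j\rVert^2 - 2x_0^\top x_j, & \frac{\partial L}{\partial x_0} &= \sum_{j\ne 0}\frac{x_0 - x_j}{\sqrt{z_{0j}}} \\
\text{no keepdim:}\quad & z_{0j} = 2\lVert x_j\rVert^2 - 2x_0^\top x_j, & \frac{\partial L}{\partial x_0} &= -\sum_{j\ne 0}\frac{x_j}{\sqrt{z_{0j}}}
\end{aligned}
$$

So even with identical forward values, the gradients for $x_0$ differ by $\sum_{j\ne 0} x_0/\sqrt{z_{0j}}$. When the whole matrix is summed instead, let $g_{ij} = 1/(2\sqrt{z_{ij}})$. The $\lVert x\rVert^2$ terms contribute $\sum_i g_{ik}\,4x_k$ to $\partial L/\partial x_k$ without keepdim and $\sum_j g_{kj}\,2x_k + \sum_i g_{ik}\,2x_k$ with it, while the $-2x_i^\top x_j$ terms are identical in both. For unit-norm rows $z$ is symmetric, so $g_{ik} = g_{ki}$ and the two contributions coincide, which is why the error only appears once a single row is selected.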
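A minimal, self-contained check of both observations, written as a sketch. The helper names (dist_method1, dist_method2, buggy, fixed, grad_of) are made up for illustration, torch.eye(n) * eps stands in for the original torch.diag(torch.ones(n) * 1e-5), and it uses float64 (an assumption, not in the original) so the cancellation in 2 - 2 * cos does not blur the comparison. From the algebra of the two expressions, the forward values and the full-matrix gradients should agree, while for the masked row-0 loss only the keepdim=True version should reproduce Method 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(2)
# float64 (assumption): keeps rounding error far below allclose tolerances
inputs = F.normalize(torch.rand((10, 2), dtype=torch.float64))
labels = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0])
EPS = 1e-5


def dist_method1(x):
    # Method 1: explicit pairwise differences, eps added on the diagonal
    n, d = x.shape
    diff = x.unsqueeze(1).expand(n, n, d) - x.unsqueeze(0).expand(n, n, d)
    z = diff.pow(2).sum(2) + torch.eye(n, dtype=x.dtype) * EPS
    return torch.sqrt(z)


def dist_method2(x, keepdim):
    # Method 2: |x_i|^2 + |x_j|^2 - 2 x_i.x_j; keepdim=False reproduces the bug
    sq_rows = torch.sum(x ** 2, dim=1, keepdim=keepdim)      # (n,) or (n, 1)
    sq_cols = torch.sum(x.t() ** 2, dim=0, keepdim=keepdim)  # (n,) or (1, n)
    z = (sq_rows + sq_cols) - 2 * torch.mm(x, x.t())
    z = z + torch.eye(x.size(0), dtype=x.dtype) * EPS
    return torch.sqrt(z)


def buggy(x):
    return dist_method2(x, keepdim=False)


def fixed(x):
    return dist_method2(x, keepdim=True)


def grad_of(dist_fn, select_row0):
    # Gradient of the masked row-0 sum (as in the question) or of the full sum
    x = inputs.clone().detach().requires_grad_(True)
    dist = dist_fn(x)
    if select_row0:
        loss = torch.masked_select(dist[0], labels != 1).sum()
    else:
        loss = dist.sum()
    loss.backward()
    return x.grad


# Forward values agree because every row has unit norm
same_forward = torch.allclose(buggy(inputs), fixed(inputs))
# Summing the full (symmetric) matrix hides the broadcasting bug
same_full_grad = torch.allclose(grad_of(buggy, False), grad_of(fixed, False))
# Selecting row 0 exposes it: only the keepdim version matches Method 1
row0_fixed_ok = torch.allclose(grad_of(dist_method1, True), grad_of(fixed, True))
row0_buggy_ok = torch.allclose(grad_of(dist_method1, True), grad_of(buggy, True))
print(same_forward, same_full_grad, row0_fixed_ok, row0_buggy_ok)
```

The buggy and fixed distance functions differ only in keepdim, so any gradient mismatch between them comes from the broadcast alone.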