The following is my code:
```python
import torch

# NOTE: the original index tensor was empty in my snippet; [2] is an
# example value -- any single index reproduces the behaviour below.
index = torch.tensor([2])
A = torch.tensor([[0.3, 0.15, 0.1, 0.1, 0.05]], requires_grad=True)
B = torch.gather(A, 1, index.view(-1, 1))
loss = (A - B.expand_as(A)) + 0.2
print('Loss = \n', loss)
mask = torch.ones_like(loss)
mask = mask.scatter_(1, index.view(-1, 1), 0)
loss = loss * mask
print('Loss After mask = \n', loss)
loss = loss.sum()
loss.backward()
print('Input and grad:')
print(A)
print(A.grad)
```
I multiply the loss by a mask that zeroes the value at the gathered position, so I expected the gradient at that position to be zero as well. However, the printed `A.grad` shows -4 at that position, and I don't understand where that value comes from.
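To narrow it down, here is a minimal comparison (the index value `[2]` is an assumption, since my original snippet lost the index tensor's contents). The first run is the computation as written; the second run is identical except that `B` is detached from the graph, which makes the masked gradient zero. This suggests the gradient is flowing back through `B = torch.gather(A, ...)`:

```python
import torch

# Assumed index value; any single index shows the same effect.
index = torch.tensor([2])

# Run 1: B stays attached to A's graph (as in my code above).
A = torch.tensor([[0.3, 0.15, 0.1, 0.1, 0.05]], requires_grad=True)
B = torch.gather(A, 1, index.view(-1, 1))
loss = (A - B.expand_as(A)) + 0.2
mask = torch.ones_like(loss).scatter_(1, index.view(-1, 1), 0)
(loss * mask).sum().backward()
print(A.grad)   # masked position gets -4, the other four positions get 1

# Run 2: same computation, but B is cut out of the graph.
A2 = torch.tensor([[0.3, 0.15, 0.1, 0.1, 0.05]], requires_grad=True)
B2 = torch.gather(A2, 1, index.view(-1, 1)).detach()
loss2 = (A2 - B2.expand_as(A2)) + 0.2
mask2 = torch.ones_like(loss2).scatter_(1, index.view(-1, 1), 0)
(loss2 * mask2).sum().backward()
print(A2.grad)  # masked position is now 0, as I expected
```

Since `B` is subtracted from every element of `A`, the gathered element appears (with a minus sign) in each of the four unmasked loss terms, which matches the -4 I see, but I would like a confirmation of this reasoning.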