Hi all,
I ran into a problem with `scatter_`.
Here is the code (it was posted here: Differentiable argmax):
```python
import torch


class ArgMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Index of the max value along dim 1
        idx = torch.argmax(input, 1)
        output = torch.zeros_like(input)
        # Put a 1 at each argmax position (one-hot along dim 1)
        output.scatter_(1, idx, 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient straight through
        return grad_output


if __name__ == '__main__':
    a = torch.rand((3, 2, 4, 2), requires_grad=True)
    print(a)
    argmax = ArgMax()
    b = argmax.apply(a)
    b.backward()
```
But it seems `scatter_` doesn't work here, and I get this error:
```
RuntimeError: invalid argument 3: Index tensor must either be empty or have same dimensions as output tensor at /Users/distiller/project/conda/conda-bld/pytorch_1570710797334/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:133
```
Is there anything wrong with how I'm using `scatter_` here?
Thanks in advance!
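For reference, the error message itself hints at a shape mismatch: `torch.argmax(input, 1)` removes dimension 1 by default, so for an input of shape `(3, 2, 4, 2)` the index `idx` has shape `(3, 4, 2)` (3 dims) while `output` has 4 dims, and `scatter_` expects `index` to have the same number of dimensions as the tensor it writes into. A minimal sketch of a possible fix, assuming the goal is a one-hot mask along dim 1 with an identity ("straight-through") gradient, is to pass `keepdim=True`. It also calls `ArgMax.apply` on the class directly and reduces `b` to a scalar before `backward()`, since `b.backward()` on a non-scalar tensor would raise a separate error once the `scatter_` problem is gone. This is a suggested variant, not the code from the original thread.

```python
import torch


class ArgMax(torch.autograd.Function):
    """One-hot argmax along dim 1 with a straight-through (identity) gradient."""

    @staticmethod
    def forward(ctx, input):
        # keepdim=True keeps dim 1 with size 1, so idx has the same
        # number of dimensions as output, which scatter_ requires.
        idx = torch.argmax(input, 1, keepdim=True)
        output = torch.zeros_like(input)
        # Write 1 at the argmax position along dim 1.
        output.scatter_(1, idx, 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: return the incoming gradient unchanged.
        return grad_output


if __name__ == '__main__':
    a = torch.rand((3, 2, 4, 2), requires_grad=True)
    b = ArgMax.apply(a)
    # b is not a scalar, so reduce it before calling backward().
    b.sum().backward()
    print(b.shape)       # torch.Size([3, 2, 4, 2])
    print(a.grad.shape)  # torch.Size([3, 2, 4, 2])
```

With `keepdim=True`, `idx` has shape `(3, 1, 4, 2)`, which matches `output` in every dimension except the scatter dimension, so each slice along dim 1 receives exactly one 1.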