Problem with scatter_

Hi all,

I ran into a problem with scatter_.
Here is the code (it was originally posted in the thread "Differentiable argmax"):

import torch

class ArgMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        idx = torch.argmax(input, 1)
        output = torch.zeros_like(input)
        output.scatter_(1, idx, 1)  # raises the RuntimeError below
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # pass the gradient through unchanged
        return grad_output

if __name__ == '__main__':
    a = torch.rand((3, 2, 4, 2), requires_grad=True)
    b = ArgMax.apply(a)

But scatter_ fails here with the following error:

RuntimeError: invalid argument 3: Index tensor must either be empty or have same dimensions as output tensor at /Users/distiller/project/conda/conda-bld/pytorch_1570710797334/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:133

Is there anything wrong with how I am using scatter_ here?

Thanks in advance!

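torch.argmax(input, 1) removes dimension 1, so idx has one dimension fewer than output, but scatter_ expects an index with the same number of dimensions. I think you want this to read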

idx = torch.argmax(input, 1, keepdim=True)
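
To make the shape difference concrete, here is a minimal sketch of the question's code with keepdim=True applied. The @staticmethod decorators and the shape prints are my additions for illustration, assuming a reasonably recent PyTorch; they are not part of the original post.

```python
import torch


class ArgMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # keepdim=True keeps dim 1 with size 1, so idx has as many
        # dimensions as output, which is what scatter_ needs.
        idx = torch.argmax(input, 1, keepdim=True)
        output = torch.zeros_like(input)
        output.scatter_(1, idx, 1)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient back unchanged.
        return grad_output


a = torch.rand((3, 2, 4, 2), requires_grad=True)

# Compare the index shapes with and without keepdim.
print(torch.argmax(a, 1).shape)                # torch.Size([3, 4, 2])
print(torch.argmax(a, 1, keepdim=True).shape)  # torch.Size([3, 1, 4, 2])

b = ArgMax.apply(a)
print(b.shape)                                 # torch.Size([3, 2, 4, 2])
```

With this change b is one-hot along dimension 1, and because backward just returns grad_output, the gradient reaching a is whatever gradient arrives at b.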

Best regards


Thanks so much, Thomas! This fixed the problem.