Misunderstanding scatter_ operation

I think I’m misunderstanding the scatter_ operation. Here’s the code that raises the error:

import torch

input_seq = torch.LongTensor([[5,1,7,8,3],[8,3,6,3,7]])
attention_target = torch.LongTensor([[3,4],[1,2]])

true_indices = torch.zeros(attention_target.size())
true_indices.scatter_(1,attention_target.long(),input_seq.float())

Essentially I want the output to be:

[[8,3],[3,6]]

But I’m getting an error stating:

Invalid index in scatter at /opt/conda/conda-bld/pytorch_1503965122592/work/torch/lib/TH/generic/THTensorMath.c:470

What am I doing wrong?

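I don’t think you need scatter for your purpose. With dim=1, scatter_(dim, index, src) writes values from src into self, setting self[i][index[i][j]] = src[i][j], so every value in index must be smaller than self.size(1). Your true_indices has only 2 columns, so only column indices 0 and 1 are valid; 2, 3 and 4 are out of range, which is what “Invalid index in scatter” is reporting. You actually want the opposite direction, reading values out of input_seq, so try indexing instead: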

# row indices, broadcast against the column indices in attention_target
input_seq[torch.tensor([[0, 0], [1, 1]]), attention_target]
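The same result can also be obtained with torch.gather, which is the built-in “read along a dimension” counterpart of scatter_. The sketch below reuses the tensors from the question (built with torch.tensor instead of the legacy torch.LongTensor constructor) and, to show the relationship, scatters the gathered values back into a tensor that is wide enough to hold column indices 3 and 4. The names gathered and restored are just illustrative.

```python
import torch

input_seq = torch.tensor([[5, 1, 7, 8, 3], [8, 3, 6, 3, 7]], dtype=torch.long)
attention_target = torch.tensor([[3, 4], [1, 2]], dtype=torch.long)

# gather reads along dim 1: out[i][j] = input_seq[i][attention_target[i][j]]
gathered = torch.gather(input_seq, 1, attention_target)
print(gathered.tolist())  # [[8, 3], [3, 6]]

# scatter_ goes the other way: self[i][attention_target[i][j]] = src[i][j].
# self must have at least 5 columns here, otherwise indices 2..4 are invalid.
restored = torch.zeros_like(input_seq)
restored.scatter_(1, attention_target, gathered)
print(restored.tolist())  # [[0, 0, 0, 8, 3], [0, 3, 6, 0, 0]]
```

gather needs no explicit row-index tensor, so it scales to any batch size without building one by hand.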