In numpy, we can use
a = np.array([[1,2],[3,4]])
a[np.arange(2), np.arange(2)]
to select one element from each row, with a separate column index for each row.
How can I do this with a PyTorch Variable?
You can use gather:
In [17]: a = torch.Tensor([[1,2],[3,4]])
In [18]: idx = torch.LongTensor([[0], [1]])
In [19]: torch.gather(a, 1, idx)
Out[19]:
1
4
[torch.FloatTensor of size 2x1]
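Here is a minimal sketch of what gather computes in this case. It assumes a reasonably recent PyTorch, where torch.tensor exists and Variables are merged into tensors, so the same code applies to a Variable. Along dim 1, gather returns out[i][j] = a[i][idx[i][j]]. The explicit loop spells that rule out, and numpy-style integer-array indexing selects the same elements directly. The names rows, cols, and manual are illustrative only.

```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx = torch.tensor([[0], [1]])  # one column index per row, shape (2, 1)

# gather along dim 1: out[i][j] = a[i][idx[i][j]]
out = torch.gather(a, 1, idx)

# Explicit loop spelling out the same rule
manual = torch.empty_like(out)
for i in range(idx.size(0)):
    for j in range(idx.size(1)):
        manual[i, j] = a[i, idx[i, j].item()]

# numpy-style integer-array indexing selects the same elements
rows = torch.arange(2)
cols = torch.arange(2)
picked = a[rows, cols]

print(out.squeeze(1))  # tensor([1., 4.])
print(picked)          # tensor([1., 4.])
```

gather keeps the shape of idx (2x1 here), while integer-array indexing returns a 1-D result of length 2.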
Hi, I'm wondering how we could use gather to select elements and write to them in place. For example,
in numpy we could write
a = np.array([[1,2],[3,4]])
a[np.arange(2), np.arange(2)] = np.ones(2)
However, when I tried this in PyTorch with gather, I got:
In [12]: torch.gather(a, 1, idx) = torch.ones(2,1)
File "<ipython-input-12-9ba3ceddfc82>", line 1
torch.gather(a, 1, idx) = torch.ones(2,1)
^
SyntaxError: can't assign to function call
It would be a great help for my project if I could do this.
Are there any suggestions?
Best regards
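A sketch of the numpy-style write, assuming a reasonably recent PyTorch that supports assignment through integer-array indexing. Note that gather returns a new tensor rather than a view, so even if assigning to its result were legal Python, it would not modify a. Assigning through the index expression itself does write into a; index_put_ is the equivalent in-place method call. The variable names are illustrative.

```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
rows = torch.arange(2)
cols = torch.arange(2)

# Writes into `a` in place, mirroring numpy's
# a[np.arange(2), np.arange(2)] = np.ones(2)
a[rows, cols] = torch.ones(2)

# Same write expressed as a method call on a fresh copy
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b.index_put_((rows, cols), torch.ones(2))

print(a)
```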
You could simply do the following.
mat = torch.randn(4, 4)
mat_t = torch.t(mat)
idx = [0, 1]
mat_t[idx] = torch.ones(mat_t.size()[1])
mat = torch.t(mat_t)
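The transpose trick works because torch.t returns a view, so writing into mat_t also writes into the underlying storage. As a sketch, assuming list-based column indexing is available (it is in recent PyTorch releases), the same whole-column overwrite can be written without transposing. The ref copy exists only to compare the two approaches.

```python
import torch

mat = torch.randn(4, 4)
idx = [0, 1]

# Transpose approach on a copy: overwrite rows of the transposed view,
# which overwrites columns 0 and 1 of `ref`, then transpose back
ref = mat.clone()
ref_t = torch.t(ref)
ref_t[idx] = torch.ones(ref_t.size(1))
ref = torch.t(ref_t)

# Direct alternative: overwrite columns 0 and 1 of `mat` in place
mat[:, idx] = 1.0
```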
Nice, thanks a lot.
However, it is not what I meant.
I don't want to overwrite whole rows, but individual elements within rows, in a more fine-grained way. For example:
In [19]: mat = torch.arange(16).view(4,4)
In [20]: idx = torch.LongTensor([[0, 2, 3, 2]])
In [21]: mat
Out[21]:
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 12., 13., 14., 15.]])
In [22]: torch.gather(mat, 0, idx)
Out[22]: tensor([[ 0., 9., 14., 11.]])
and I just want to replace the gathered elements with other values.
It looks like I'm really trying to do a scatter with gather. Maybe I should look at a scatter method instead?
However, it would be nice if PyTorch supported this directly with the gather method, just like in numpy.
OK, the torch.Tensor.scatter_ method can solve my problem.
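A sketch of how scatter_ mirrors the gather call above, assuming the -1.0 replacement values (a hypothetical choice for illustration). With the same dim and index, gather reads out[0][j] = mat[idx[0][j]][j], and scatter_ writes mat[idx[0][j]][j] = src[0][j], so it overwrites exactly the elements gather read. The integer-array assignment at the end is a cross-check that performs the same write numpy-style.

```python
import torch

mat = torch.arange(16, dtype=torch.float32).view(4, 4)
idx = torch.tensor([[0, 2, 3, 2]])

# Read: picks mat[0][0], mat[2][1], mat[3][2], mat[2][3] (values 0, 9, 14, 11)
gathered = torch.gather(mat, 0, idx)

# Write: same dim and index, so the same four positions are overwritten in place
src = torch.full((1, 4), -1.0)
mat.scatter_(0, idx, src)

# Cross-check with numpy-style integer-array assignment:
# row indices come from idx, column indices are 0..3
mat2 = torch.arange(16, dtype=torch.float32).view(4, 4)
mat2[idx[0], torch.arange(4)] = src[0]
```

Here each column receives exactly one write. If an index tensor sent several source values to the same position, the PyTorch documentation does not guarantee which value ends up there.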
Really nice
Thanks