How to select elements in PyTorch like in NumPy?

In NumPy, we can use

import numpy as np
a = np.array([[1,2],[3,4]])
a[np.arange(2), np.arange(2)]

to select one element from each row by index.
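For reference, here is that NumPy snippet as a runnable script next to the same selection written with advanced indexing in PyTorch. This is a sketch that assumes a reasonably recent PyTorch (0.4 or later), where tensors accept NumPy-style integer-array indexing. `a_t`, `picked_np` and `picked_t` are just illustrative names.

```python
import numpy as np
import torch

# NumPy: pair row indices [0, 1] with column indices [0, 1],
# which selects a[0, 0] and a[1, 1] (one element per row).
a = np.array([[1, 2], [3, 4]])
picked_np = a[np.arange(2), np.arange(2)]
print(picked_np)  # [1 4]

# PyTorch: the same integer-array ("advanced") indexing works on tensors.
a_t = torch.tensor([[1, 2], [3, 4]])
picked_t = a_t[torch.arange(2), torch.arange(2)]
print(picked_t)  # tensor([1, 4])
```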

How can I do the same with a PyTorch Variable?

You can use torch.gather:

In [17]: a = torch.Tensor([[1,2],[3,4]])

In [18]: idx = torch.LongTensor([[0], [1]])

In [19]: torch.gather(a, 1, idx)
Out[19]:

 1
 4
[torch.FloatTensor of size 2x1]
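To make the indexing rule of `torch.gather` explicit: for `dim=1` it computes `out[i][j] = a[i][idx[i][j]]`, so `idx` has one row per row of `a` and each entry picks a column. The sketch below checks that rule against a plain Python loop; the loop is only illustrative and is not how `gather` is implemented.

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
idx = torch.tensor([[0], [1]])  # int64 (LongTensor) index, shape 2x1

out = torch.gather(a, 1, idx)  # shape 2x1, holds a[0][0] and a[1][1]

# Same result with an explicit loop: out[i][j] = a[i][idx[i][j]]
manual = torch.empty_like(out)
for i in range(idx.size(0)):
    for j in range(idx.size(1)):
        manual[i, j] = a[i, idx[i, j]]

print(out.squeeze(1))  # tensor([1., 4.])
```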

Hi, I'm wondering how we could use gather to select elements and write to them in place. For example,
in NumPy we could write

a = np.array([[1,2],[3,4]])
a[np.arange(2), np.arange(2)] = np.ones(2)
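In recent PyTorch versions (assumed here to be 0.4 or later), the NumPy-style assignment can be written directly, with advanced indexing on the left-hand side writing into the original tensor in place. A minimal sketch; `rows` and `cols` are illustrative names:

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
rows = torch.arange(2)
cols = torch.arange(2)

# Writes a[0, 0] and a[1, 1] in place, like the NumPy version.
a[rows, cols] = torch.ones(2)
print(a)
```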

However, when I tried the same thing in PyTorch with gather, I got:

In [12]: torch.gather(a, 1, idx) = torch.ones(2,1)
  File "<ipython-input-12-9ba3ceddfc82>", line 1
    torch.gather(a, 1, idx) = torch.ones(2,1)
SyntaxError: can't assign to function call
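The SyntaxError comes from Python itself: the target of `=` must be something like a name, an attribute, or a subscript, never a function call. Mutating the result of `gather` afterwards would not help either, because `gather` returns a new tensor rather than a view into `a`, as this small check suggests:

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
idx = torch.tensor([[0], [1]])

g = torch.gather(a, 1, idx)
g.fill_(0.)  # modifies only the gathered copy

print(a)  # a is unchanged
```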

It would be very helpful for my project if I could do this.
Are there any suggestions?
Best regards

You could simply do the following, which overwrites whole columns (here columns 0 and 1) of mat:

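import torch

mat = torch.randn(4, 4)
mat_t = torch.t(mat)  # transpose, so columns of mat become rows of mat_t
idx = [0, 1]
mat_t[idx] = torch.ones(mat_t.size()[1])  # overwrite rows 0 and 1 of mat_t
mat = torch.t(mat_t)  # transpose back: columns 0 and 1 of mat are now ones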
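Since `torch.t` returns a view that shares storage with `mat`, the row assignment on `mat_t` already writes into columns 0 and 1 of `mat`; the final transpose only restores the original orientation. The same effect can be had without transposing by indexing the columns directly. A sketch comparing the two, where `expected` is an illustrative name:

```python
import torch

mat = torch.randn(4, 4)
expected = mat.clone()
expected[:, [0, 1]] = 1.0  # set columns 0 and 1 directly

mat_t = torch.t(mat)  # view: shares storage with mat
mat_t[[0, 1]] = torch.ones(mat_t.size(1))  # broadcast over both selected rows
mat = torch.t(mat_t)

assert torch.equal(mat, expected)
```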

Nice, thanks a lot.
However, that is not quite what I meant.
I don't want to overwrite whole rows, but individual elements within rows, in a more fine-grained way. For example:

In [19]: mat = torch.arange(16).view(4,4)

In [20]: idx = torch.LongTensor([[0, 2, 3, 2]])

In [21]: mat
tensor([[  0.,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.],
        [ 12.,  13.,  14.,  15.]])

In [22]: torch.gather(mat, 0, idx)
Out[22]: tensor([[  0.,   9.,  14.,  11.]])

and I just want to replace the gathered elements with other values.
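For `dim=0`, `gather` computes `out[0][j] = mat[idx[0][j]][j]`, i.e. one element per column. Those same (row, column) pairs can be passed to advanced indexing on the left-hand side to overwrite exactly the gathered elements in place. A sketch assuming PyTorch 0.4 or later; `rows`, `cols` and the value `-1` are illustrative choices:

```python
import torch

mat = torch.arange(16.).view(4, 4)
idx = torch.tensor([[0, 2, 3, 2]])

gathered = torch.gather(mat, 0, idx)  # values 0, 9, 14, 11

rows = idx[0]                      # row index chosen for each column
cols = torch.arange(idx.size(1))   # columns 0..3
assert torch.equal(mat[rows, cols], gathered[0])

mat[rows, cols] = -1.0  # overwrite only the gathered positions
```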

It looks like I am trying to do a scatter in a gather way. Maybe I should look at a scatter method?
Still, it would be nice if PyTorch supported this directly through gather, the way NumPy does with indexing.

OK, the torch.Tensor.scatter_ method solves my problem.
Really nice!
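`scatter_` acts as the in-place counterpart of `gather`: with the same `dim` and `index`, `mat.scatter_(0, idx, src)` writes `src[0][j]` into `mat[idx[0][j]][j]`. A minimal sketch for the 4x4 example above, where the replacement values in `src` are made up for illustration:

```python
import torch

mat = torch.arange(16.).view(4, 4)
idx = torch.tensor([[0, 2, 3, 2]])  # same index used with gather
src = torch.tensor([[100., 200., 300., 400.]])  # hypothetical new values

mat.scatter_(0, idx, src)  # in place: mat[idx[0][j]][j] = src[0][j]

# Gathering with the same index now returns the new values.
assert torch.equal(torch.gather(mat, 0, idx), src)
```

To write a single constant instead of a tensor of values, `scatter_` also accepts a scalar, e.g. `mat.scatter_(0, idx, 1.0)`.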
