Indexing a 4D tensor via two 2D tensors containing indices for specific dimensions

Good day everyone!

I’ve seen some posts similar to what I’m looking for, but none of the solutions seem to be exactly what I want to do.

I have a 4-dimensional tensor, let’s call it T, of shape n x n x s x b. Then I have two tensors, i and j, each of shape s x b. Tensors i and j contain indices for the first and second dimensions of T, respectively (values between 0 and n-1).

I was trying to do something like this to index tensor T with i and j (I believed this could be done):

T[i, j, :, :] = some_data

where some_data is a tensor of shape s x b. This doesn’t work as I expected, however: it seems to write values to multiple positions. Below are some examples of what happens.

Example 1 (all good here):

>>> import torch
>>> i = torch.ones([5,18]).long()
>>> j = torch.ones([5,18]).long()*2
>>> T = torch.zeros([3,3,5,18])
>>> some_data = torch.ones([5,18])*4
>>> T[i,j,:,:] = some_data
>>> T[:,:,0,0]
tensor([[0., 0., 0.],
        [0., 0., 4.],
        [0., 0., 0.]])
>>> torch.sum(T)
tensor(360.)
>>> 4*5*18 # what the data should sum to
360

Example 2 (with randomized indices; here the results are not what I expect):

>>> import torch
>>> i = torch.randint(low=1,high=3,size=[5,18])
>>> j = torch.randint(low=1,high=3,size=[5,18])
>>> some_data = torch.ones([5,18])*4
>>> T = torch.zeros([3,3,5,18])
>>> T[i,j,:,:] = some_data
>>> i[0,0], j[0,0]
(tensor(2), tensor(1))
>>> T[:,:,0,0] # should only have data at position [2,1], but other positions are set as well
tensor([[0., 0., 0.],
        [0., 4., 4.],
        [0., 4., 4.]])
>>> torch.sum(T)
tensor(1440.)
>>> 4*5*18 # what the data should sum to (or what I wanted it to sum to)
360

Example 3 (with real data, where the indices are not all the same, i.e. not just constant tensors; the problem occurs here too):

>>> import torch
>>> i = torch.tensor([[1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1],
...         [1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1],
...         [1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1],
...         [1, 1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2],
...         [2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2]])
>>> j = torch.tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
...         [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
...         [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
...         [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
...         [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])
>>> # actually, the data here could also be the same as in example 1; the same situation happens
>>> some_data = torch.tensor([[0.0642, 0.1811, 0.1436, 0.0756, 0.1513, 0.3136, 0.3522, 0.0376, 0.0286,
...          0.0413, 0.0604, 0.3314, 0.4650, 0.0230, 0.1622, 0.0016, 0.2878, 0.0514],
...         [0.0541, 0.1553, 0.1813, 0.0651, 0.1507, 0.3136, 0.3522, 0.0424, 0.0178,
...          0.0307, 0.0443, 0.3947, 0.4650, 0.0230, 0.1316, 0.0132, 0.2439, 0.0253],
...         [0.0298, 0.1272, 0.2162, 0.0475, 0.1579, 0.3136, 0.3522, 0.0435, 0.0122,
...          0.0184, 0.0247, 0.4607, 0.4650, 0.0230, 0.1080, 0.0281, 0.2109, 0.0000],
...         [0.0036, 0.0946, 0.2541, 0.0287, 0.1731, 0.3136, 0.3522, 0.0430, 0.0028,
...          0.0058, 0.0060, 0.5367, 0.4650, 0.0230, 0.0853, 0.0463, 0.1868, 0.3064],
...         [0.4510, 0.0626, 0.2908, 0.0071, 0.1859, 0.3136, 0.3522, 0.0397, 0.1229,
...          0.2560, 0.3906, 0.6065, 0.4650, 0.0184, 0.0598, 0.0674, 0.1602, 0.2811]])
>>> i.shape, j.shape, some_data.shape
(torch.Size([5, 18]), torch.Size([5, 18]), torch.Size([5, 18]))
>>> T = torch.zeros([3,3,5,18])
>>> T[i,j,:,:] = some_data
>>> T[:,:,0,0] # should only have a value at [1,2], but [2,2] is set as well
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0642],
        [0.0000, 0.0000, 0.0642]])

I don’t think this is a bug, but rather that the indexing behaviour is not what I think it is. In case it is a bug, though: I’m using PyTorch 1.7.1 and NumPy 1.19.2 (I believe the same behaviour occurs in NumPy as well, but I haven’t tested it thoroughly).
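For reference, this is roughly what I would run to check the NumPy side; I haven’t actually tried it, so treat it as a sketch of the same pattern as example 2:

import numpy as np

# same setup as example 2, but in NumPy
i = np.random.randint(low=1, high=3, size=(5, 18))
j = np.random.randint(low=1, high=3, size=(5, 18))
some_data = np.ones((5, 18)) * 4
T = np.zeros((3, 3, 5, 18))
T[i, j, :, :] = some_data
# if NumPy broadcasts the assignment the same way, this prints more than 360
# whenever i, j contain more than one distinct index pair
print(T.sum())
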
What is the proper way to perform such indexing? In the meantime I’m just making do with two for loops that use one index pair from i, j at a time, roughly like the sketch below. With that, the results are correct.
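
A minimal sketch of that workaround (s and b are just the last two sizes of T):

s, b = T.shape[2], T.shape[3]  # the last two dimensions of T
for x in range(s):
    for y in range(b):
        # place some_data[x, y] at row i[x, y], column j[x, y] of the first two dims
        T[i[x, y], j[x, y], x, y] = some_data[x, y]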

Thank you!