Efficiently expand a tensor using higher-dimensional mask

Let us assume I’m working with a 2D tensor data of shape (4, 1) whose contents are given by [[1], [2], [3], [4]]. Now let expanded_data = torch.empty(4, 2, 1) and let indices be a 2D tensor of shape (4, 2) whose contents are given by [[2, 3], [0, 2], [0, 1], [1, 2]]. I would like to use indices as a mask to “expand” the original data tensor by populating expanded_data such that expanded_data[i] = data[indices[i]]. In this example, the contents of expanded_data should look like

[[[3], [4]], [[1], [3]], [[1], [2]], [[2], [3]]].

My naive implementation would be as follows:

for i in range(expanded_data.shape[0]):
    for idx, j in enumerate(indices[i]):
        expanded_data[i][idx] = data[j]

Is there a more efficient way to do this? A solution would preferably avoid for-loops and run entirely on the GPU.

Hi Gcortes!

You can do this without loops by using tensor indexing (with squeeze()
and unsqueeze() to get the shapes the way you want them):

>>> import torch
>>> torch.__version__
'1.12.0'
>>> data = torch.tensor ([[1], [2], [3], [4]])
>>> indices = torch.tensor ([[2, 3], [0, 2], [0, 1], [1, 2]])
>>> expanded_data = torch.empty(4, 2, 1)
>>> for i in range(expanded_data.shape[0]):
...     for idx, j in enumerate(indices[i]):
...         expanded_data[i][idx] = data[j]
...
>>> expanded_dataB = data.squeeze()[indices.unsqueeze (2)].float()
>>> torch.equal (expanded_data, expanded_dataB)
True

Best.

K. Frank

@KFrank I’m getting the following error from your solution:

IndexError: too many indices for tensor of dimension 1

Hi Gcortes!

Could you post a minimal but complete and runnable script that reproduces
your issue?

None of the tensors you mentioned in your original post are of dimension 1,
and the script whose output I posted works. You might try checking that the
tensors in your failed version have dimensions consistent with those in your
original post or in my example.
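One quick way to check is to print dim() and shape for each tensor just before the indexing expression. The snippet below (a suggested diagnostic, not from the thread) shows the values you should see with the tensors from the original post:

```python
import torch

# Tensors as defined in the original post.
data = torch.tensor([[1], [2], [3], [4]])
indices = torch.tensor([[2, 3], [0, 2], [0, 1], [1, 2]])

# Print dimensionality and shape so any mismatch is obvious.
print('data:', data.dim(), tuple(data.shape))           # data: 2 (4, 1)
print('indices:', indices.dim(), tuple(indices.shape))  # indices: 2 (4, 2)

# data.squeeze() must come out 1-D for the indexing expression to work.
print('data.squeeze():', data.squeeze().dim())          # data.squeeze(): 1
```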

Best.

K. Frank