Index select with potentially multiple indices

I have a tensor A of size B x 4 and another tensor B of size H x N. A holds integer values, many of which are -1, for example:

tensor([[ 2.2640e+03, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [ 3.5400e+02, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [ 6.2700e+02, 6.2900e+02, -1.0000e+00, -1.0000e+00]])

I want a new tensor of size B x H whose rows are taken from B according to the indices in A. For example, the first row of A is [2264, -1, -1, -1], so I want the size-H column of B at position 2264 (i.e., B[:, 2264]). Some rows of A may contain more than one index greater than -1; in that case I would like to take the corresponding columns of B and average them.
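To pin down the intended semantics, here is a minimal, deliberately slow loop sketch. The helper name `gather_average_loop` is hypothetical, and it uses small stand-in tensors (H = 5, N = 10) rather than the real data:

```python
import torch

def gather_average_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: row i of the result is the mean of the
    columns of B (H x N) whose indices appear in A[i]; -1 marks padding."""
    rows = []
    for a_row in A.long():                     # .long() also accepts a float A
        idx = a_row[a_row > -1]                # keep only the real indices
        rows.append(B[:, idx].mean(dim=1))     # (H, k) -> (H,)
    return torch.stack(rows)                   # (rows of A, H)

# Small stand-in tensors: H = 5, N = 10
A = torch.tensor([[2, -1, -1, -1],
                  [3, -1, -1, -1],
                  [7,  6, -1, -1]])
B = torch.arange(50.0).view(5, 10) + 1
print(gather_average_loop(A, B))
```

A row of A containing only -1 would come out as NaN here, since it averages zero columns.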

torch.index_select(b, 0, a[:, 0]) almost gives the correct answer. I just need a way to handle rows that contain more than one valid index, like the third row of the example above.

Thanks for any help you can provide.

Hi Afonso!

Because you want to average multiple entries of B, trying to use
index_select() won’t fit your use case.
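A small sketch of that limitation, assuming B is laid out H x N as described in the question (so the columns live along dim=1) and using toy values for A and B: index_select returns exactly one slice per index, so each row of A can contribute only one column.

```python
import torch

A = torch.tensor([[2, -1, -1, -1],
                  [3, -1, -1, -1],
                  [7,  6, -1, -1]])
B = torch.arange(50.0).view(5, 10) + 1   # H = 5, N = 10

# One column of B per index: selecting with the first index of each row of A
# gives an H x (rows of A) result, transposed here to (rows of A) x H.
first_only = torch.index_select(B, 1, A[:, 0]).T   # shape (3, 5)

# The third row holds only column 7 of B, not the average of columns 7 and 6.
print(first_only[2])
```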

Use A to pluck the desired values out of B with PyTorch tensor indexing
and then compute the average "manually."
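One way to lay out the same pluck-then-average idea is to index the transpose B.T (N x H) instead of B, so the rows of A come first and no final transpose is needed. This is a sketch with toy values, not the only arrangement; it replaces the -1 padding with 0 so the indexing is legal and then zeroes those slots out with a mask:

```python
import torch

A = torch.tensor([[2, -1, -1, -1],
                  [3, -1, -1, -1],
                  [7,  6, -1, -1]])
B = torch.arange(50.0).view(5, 10) + 1   # H x N

mask = A > -1                      # True where A holds a real index
safe = A.clamp(min=0)              # replace -1 by 0 so indexing is legal
picked = B.T[safe]                 # (rows, 4, H): column safe[i, j] of B
summed = (picked * mask.unsqueeze(-1)).sum(dim=1)    # padded slots contribute 0
result = summed / mask.sum(dim=1, keepdim=True)      # (rows, H)
print(result)
```

The intermediate `picked` has one size-H slice per entry of A, so memory grows as (rows of A) x 4 x H, the same as indexing B directly.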

Here is an illustration of such a scheme with a version of A that contains
smaller index values:

>>> import torch
>>>
>>> torch.__version__
'1.12.0'
>>>
>>> A = torch.tensor ([
...     [ 2, -1, -1, -1],
...     [ 3, -1, -1, -1],
...     [ 7,  6, -1, -1]
... ])
>>>
>>> B = torch.arange (50.0).view (5, 10) + 1
>>> B
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        [11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
        [21., 22., 23., 24., 25., 26., 27., 28., 29., 30.],
        [31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
        [41., 42., 43., 44., 45., 46., 47., 48., 49., 50.]])
>>>
>>> Amask = A > -1
>>> Aind = torch.max (A, torch.zeros (1, dtype = torch.long))
>>> result = ((B[:, Aind] * Amask).sum (dim = -1) / Amask.sum (dim = -1)).T
>>>
>>> result
tensor([[ 3.0000, 13.0000, 23.0000, 33.0000, 43.0000],
        [ 4.0000, 14.0000, 24.0000, 34.0000, 44.0000],
        [ 7.5000, 17.5000, 27.5000, 37.5000, 47.5000]])
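The session above works on a toy A of integers. The A in the question prints as a float tensor, and real data might contain rows that are entirely -1, for which the one-liner would divide 0 by 0. A hedged sketch of a small helper (the name `masked_column_mean` is hypothetical) that casts A to long and, as an assumed design choice, returns zeros for such rows instead of NaN:

```python
import torch

def masked_column_mean(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: row i of the result is the mean of the columns of B
    (H x N) whose indices appear in A[i] (rows x K); negative entries are padding.
    Rows of A with no valid index give a row of zeros (an assumed choice)."""
    A = A.long()                                        # A may arrive as floats
    mask = A >= 0                                       # valid-index mask, (rows, K)
    picked = B.T[A.clamp(min=0)]                        # (rows, K, H)
    summed = (picked * mask.unsqueeze(-1)).sum(dim=1)   # (rows, H)
    counts = mask.sum(dim=1, keepdim=True).clamp(min=1) # avoid 0 / 0
    return summed / counts

A = torch.tensor([[ 2., -1., -1., -1.],
                  [ 7.,  6., -1., -1.],
                  [-1., -1., -1., -1.]])   # last row has no valid index
B = torch.arange(50.0).view(5, 10) + 1
print(masked_column_mean(A, B))
```

If NaN is the preferred marker for empty rows, dropping the `.clamp(min=1)` on `counts` restores the plain 0 / 0 behavior.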

Best.

K. Frank