How to use max_indices given by torch.max(...)?

Hello,

I have a (soft) adjacency tensor adj of size B x N x M with batch dimension B. When I call adj.max(dim=2) I get a tuple of max values and max-value indices, indicating for each row along the N dimension where it is maximal with respect to the M dimension. Now I would like to use the returned max-value indices to select entries from another tensor features of size B x N x M x 10, so that I get a tensor of size B x N x 10 that keeps, along dim=2, only the entries where adj was maximal with respect to dim=2.

I have been able to do this without the batch dimension:

import torch

adj = torch.tensor([[1,0,0],[0,1,0],[0,1,0],[0,0,1]]) # 4 x 3
features = torch.stack([torch.arange(12).reshape((4,3))]*10).permute((1,2,0)) # 4 x 3 x 10

m = adj.max(dim=1)[1] # indices of the max along dim=1
result = features[range(4),m]
print(result)
#result is:
#tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
#        [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  4],
#        [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  7],
#        [11, 11, 11, 11, 11, 11, 11, 11, 11, 11]])

Using range(4) there doesn’t seem like the correct way to handle this to me. Also, I have no idea how to do this with an additional batch dimension. Any ideas?

Hi,

You can use these indices with the other_tensor.gather(dim, indices) function.
Note that you either need to use keepdim=True when you call max, or unsqueeze the indices before giving them to gather, so that the index tensor has the same number of dimensions as the tensor you gather from.
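A minimal 2D sketch of the keepdim point (toy values, not the original tensors): with keepdim=True the indices keep the reduced dimension, so they already have the rank gather expects.

```python
import torch

x = torch.tensor([[1., 5., 3.],
                  [7., 2., 4.]])         # 2 x 3
vals, idx = x.max(dim=1, keepdim=True)   # both 2 x 1
# gather needs an index tensor with the same rank as x;
# thanks to keepdim=True, idx is already 2 x 1 rather than flat
picked = x.gather(1, idx)                # 2 x 1, equals vals
```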


Thanks, it appears tensor.gather(...) is the right tool. Unfortunately, I still couldn’t get it to work. When I add an extra dimension at the end of the index tensor and use keepdim=True, features.gather(1, m) only returns a tensor of size 4 x 1 x 1, while I need the whole 4 x 1 x 10 tensor. So gather effectively ignores the last dimension of size 10 of features and only picks the first element along it.

Update: I got it to work by using m.repeat(1,1,10) and passing that as the index tensor to the gather call. This still doesn’t seem like the most elegant solution to me. Next I need to figure out how to do this with batches.
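For reference, the repeat-based version described above can be sketched like this on the unbatched example from the first post (gather requires the index tensor to match features in every dimension except the one being gathered over, hence the repeat across the trailing size-10 dimension):

```python
import torch

adj = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]])  # 4 x 3
features = torch.stack([torch.arange(12).reshape((4, 3))] * 10).permute((1, 2, 0))  # 4 x 3 x 10

m = adj.max(dim=1, keepdim=True)[1].unsqueeze(-1)  # 4 x 1 x 1
# repeat the indices across the feature dimension so the index
# tensor has shape 4 x 1 x 10, matching features except along dim=1
result = features.gather(1, m.repeat(1, 1, 10)).squeeze(1)  # 4 x 10
```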

Update 2: Batching it is no problem, it’s literally the same with one additional dimension. I’m still interested in whether there is a better way than calling repeat. :confused:
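The batched version mentioned above might look like the following sketch (random tensors just to show the shapes; B, N, M, F are placeholder sizes, not from the original post):

```python
import torch

B, N, M, F = 2, 4, 3, 10
adj = torch.rand(B, N, M)          # soft adjacency, B x N x M
features = torch.rand(B, N, M, F)  # B x N x M x F

idx = adj.max(dim=2, keepdim=True)[1].unsqueeze(-1)  # B x N x 1 x 1
# same pattern as before, one dimension earlier: broadcast the
# indices across the feature dimension, gather along dim=2
result = features.gather(2, idx.repeat(1, 1, 1, F)).squeeze(2)  # B x N x F
```

As a possible answer to the repeat question: idx.expand(-1, -1, -1, F) should work in place of repeat here, since gather accepts an expanded (stride-0) view and no actual copy of the indices is made.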
