Hello,
I have a (soft) adjacency tensor `adj` of size B x N x M with batch dimension B. When I call `adj.max(dim=2)`, I get a tuple of max values and max value indices. For each row along the N dimension, the index says where along the M dimension that row is maximal. Now I would like to use the returned max value indices to select entries from another tensor `features` of size B x N x M x 10. The result should be a tensor of size B x N x 10 that keeps only the entries along its dim=2 where `adj` was maximal with respect to dim=2.
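To make the shapes concrete, here is a small sketch with made-up sizes (B=2, N=4, M=3) and random data; these values are only for illustration. It shows that `adj.max(dim=2)` returns a `(values, indices)` pair where both are B x N, and that the indices point into dim=2 of `adj`:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
B, N, M = 2, 4, 3
adj = torch.rand(B, N, M)           # soft adjacency, B x N x M
features = torch.rand(B, N, M, 10)  # B x N x M x 10

# max over dim=2 returns a (values, indices) named tuple; both are B x N.
values, idx = adj.max(dim=2)
print(values.shape, idx.shape)  # torch.Size([2, 4]) torch.Size([2, 4])

# idx[b, n] is the position along M where adj[b, n, :] is maximal,
# so gathering adj at those positions gives back the max values.
recovered = adj.gather(2, idx.unsqueeze(-1)).squeeze(-1)  # B x N
print(torch.equal(recovered, values))  # True
```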
I have been able to do this without the batch dimension:
```python
import torch

adj = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]])  # 4 x 3
# features[i, j, :] is filled with the value i * 3 + j
features = torch.stack([torch.arange(12).reshape((4, 3))] * 10).permute((1, 2, 0))  # 4 x 3 x 10
m = adj.max(dim=1)[1]  # index of the maximum in each row, shape 4
result = features[range(4), m]  # pick features[i, m[i], :] for each row i -> 4 x 10
print(result)
# result is:
# tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
#         [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  4],
#         [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  7],
#         [11, 11, 11, 11, 11, 11, 11, 11, 11, 11]])
```
Using `range(4)` there just doesn't seem like the correct way to handle this to me. I also have no idea how to do this with an additional batch dimension. Any ideas?
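One possible approach, sketched with hypothetical sizes (B, N, M, F) and random tensors rather than real data: instead of `range(4)`, build `torch.arange` index tensors shaped so they broadcast against the B x N index tensor, and let advanced indexing pick `features[b, n, idx[b, n], :]`. An alternative is `torch.gather` along dim=2, which needs the index expanded to the same number of dimensions as `features`. In the unbatched case the first idea reduces to `features[torch.arange(N), idx]`.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes for illustration; F plays the role of the trailing 10.
B, N, M, F = 2, 4, 3, 10
adj = torch.rand(B, N, M)
features = torch.randn(B, N, M, F)

idx = adj.max(dim=2).indices  # B x N, position along M of each row's max

# Option 1: advanced indexing with broadcasting index tensors.
b = torch.arange(B).unsqueeze(1)  # B x 1
n = torch.arange(N).unsqueeze(0)  # 1 x N
out_index = features[b, n, idx]   # b, n, idx broadcast to B x N -> result B x N x F

# Option 2: torch.gather along dim=2; the index must have as many dims as features.
gather_idx = idx[:, :, None, None].expand(B, N, 1, F)   # B x N x 1 x F
out_gather = features.gather(2, gather_idx).squeeze(2)  # B x N x F

print(out_index.shape)                    # torch.Size([2, 4, 10])
print(torch.equal(out_index, out_gather)) # True
```

The indexing version reads closest to the original `features[range(4), m]`, while `gather` avoids building the extra arange tensors and generalizes to picking several positions along dim=2.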