I’m trying to gather features from a feature tensor using a known index tensor. The feature tensor `G` has shape N×L×C×H×W, and the index tensor `index` has shape V×H×W, where every element is an integer in [0, L). I want to build a tensor `Gr` of shape N×V×C×H×W by picking, for each v in range(V) and each spatial position (h, w), the slice of `G` along the L dimension that `index[v, h, w]` points to.
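Written out elementwise, the intended result is:

$$
G_r[n, v, c, h, w] = G\big[n,\ \mathrm{index}[v, h, w],\ c,\ h,\ w\big], \qquad 0 \le \mathrm{index}[v, h, w] < L
$$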
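```python
import torch

Gr = torch.zeros(N, V, C, H, W, device=device)
hs = torch.arange(H).view(H, 1)  # (H, 1): broadcasts with ws to cover the full H x W grid
ws = torch.arange(W).view(1, W)  # (1, W); two 1-D aranges would only pair up (i, i)
for v in range(V):
    # advanced indices on dims 1, 3, 4 are separated by a slice, so the
    # indexed result comes out as (H, W, N, C); move N and C back in front
    Gr[:, v, :, hs, ws] = G[:, index[v, hs, ws], :, hs, ws].permute(2, 3, 0, 1)
```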
I think my code is correct, but it’s too slow. Is there a simple way to vectorize this loop that I’m missing?
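One possible vectorization, sketched under the assumption that `index` holds integers (it is cast to int64, which `torch.gather` requires): `torch.gather(G, 1, idx)` computes `out[n, v, c, h, w] = G[n, idx[n, v, c, h, w], c, h, w]`, so broadcasting `index` to shape (N, V, C, H, W) with `expand` performs the whole selection in a single call, with no Python loop. The helper name `select_features` and the small sizes in the check are illustrative only, not part of the original setup.

```python
import torch


def select_features(G: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return Gr of shape (N, V, C, H, W) with
    Gr[n, v, c, h, w] = G[n, index[v, h, w], c, h, w].

    G:     (N, L, C, H, W) feature tensor
    index: (V, H, W) integer tensor with values in [0, L)
    """
    N, L, C, H, W = G.shape
    V = index.shape[0]
    # gather needs an int64 index with as many dims as G: add singleton
    # N and C dims, then expand (a view, no data copy) to (N, V, C, H, W).
    idx = index.long().reshape(1, V, 1, H, W).expand(N, V, C, H, W)
    # gather along dim 1: out[n, v, c, h, w] = G[n, idx[n, v, c, h, w], c, h, w]
    return torch.gather(G, 1, idx)


# Usage: compare against a plain elementwise loop on small random inputs.
N, L, C, H, W, V = 2, 4, 3, 5, 6, 7
G = torch.randn(N, L, C, H, W)
index = torch.randint(0, L, (V, H, W))

Gr = select_features(G, index)

ref = torch.zeros(N, V, C, H, W)
for v in range(V):
    for h in range(H):
        for w in range(W):
            ref[:, v, :, h, w] = G[:, int(index[v, h, w]), :, h, w]

print(Gr.shape)  # torch.Size([2, 7, 3, 5, 6])
print(torch.equal(Gr, ref))  # True
```

Note that `torch.gather` puts no size constraint on the gathered dimension, so V may be larger or smaller than L; the other dimensions of the index only need to fit within those of `G`. The triple loop above is only there as a correctness check, not something to keep.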
Thanks!