# Efficiently retrieving features using a known index tensor

I’m trying to retrieve features from a features tensor using a known index tensor. The feature tensor G has shape N*L*C*H*W, while my index has shape V*H*W, where every element’s value lies in [0, L). I’m trying to construct a tensor Gr by selecting, for each channel v of the index, the corresponding slices of G.

```python
Gr = torch.zeros(N, V, C, H, W, device=device)
hs = torch.arange(H)
ws = torch.arange(W)
for v in range(V):
    Gr[:, v, :, hs, ws] = G[:, index[v, hs, ws], :, hs, ws].permute(1, 2, 0)
```

I think my code is correct; however, it’s too slow. Is there a simple way to vectorize this fragment that I’m missing?

Thanks!

I’m unsure how your code is working, as a few things are undefined. Based on your description of `index`, I assume you are using something like this:

```python
N, V, C, H, W = 2, 3, 4, 5, 6

G_kvp = torch.randn(N, V, C, H, W)
out = torch.zeros_like(G_kvp)
index = torch.randint(0, V, (V, H, W))
hs = torch.arange(H)
ws = torch.arange(W)
for v in range(V):
    out[:, v, :, hs, ws] = G_kvp[:, index[v, hs, ws], :, hs, ws].permute(1, 2, 0)
```

which gives an error in `index[v, hs, ws]`. Could you post the slow reference code, please?

Thanks for your response! And sorry, I changed G_kvp to Gr in the question but forgot to fix it in the code. Here’s a minimal example that works for me.

```python
import torch
import random

N, L, V, C, H, W = 1, 64, 1024, 3, 128, 128
G = torch.zeros(N, L, C, H, W)
index = torch.zeros(V, H, W)
# just to keep it simple; randint is inclusive at both ends, so cap at L - 1
index += random.randint(0, L - 1)
index = index.to(dtype=torch.long)

G_kvp = torch.zeros(N, V, C, H, W)
hs = torch.arange(H)
ws = torch.arange(W)
for v in range(V):
    G_kvp[:, v, :, hs, ws] = G[:, index[v, hs, ws], :, hs, ws].permute(1, 2, 0)

print(G_kvp.shape)
```

I looked at the code from DGCNN, where there is a similar indexing operation implementing EdgeConv. The basic idea is to flatten the feature and index tensors into 1D (adding offsets according to row-major order, which PyTorch tensors follow) and then index as one normally would.

```python
import torch

N, L, V = 1, 32, 128
C, H, W = 3, 64, 64

device = "cuda" if torch.cuda.is_available() else "cpu"
G = torch.randn(N, L, C, H, W, device=device)

index = torch.randint(0, L, (V, H, W), device=device)

# extremely naive reference
G_kvp = torch.zeros(N, V, C, H, W, device=device)
for v in range(V):
    for h in range(H):
        for w in range(W):
            G_kvp[:, v, :, h, w] = G[:, index[v, h, w], :, h, w]

# flatten (L, H, W): in row-major order, element (l, h, w) sits at offset l*H*W + h*W + w
base = torch.arange(0, H * W, device=device)
flat_index = index.reshape(V, -1) * (H * W) + base   # V * (H.W)
G_flat = G.permute(0, 2, 1, 3, 4).reshape(N, C, -1)  # N * C * (L.H.W)
# get the indexed values from G
G_kvp_new = G_flat[:, :, flat_index]                 # N * C * V * (H.W)
# reshape it to the expected feature tensor
G_kvp_new = G_kvp_new.permute(0, 2, 1, 3).reshape(N, V, C, H, W)

print(torch.all(G_kvp == G_kvp_new))
```

Though there is probably a cleaner way to do it via einsum/einops.
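For what it’s worth, `torch.gather` can express the same lookup without the manual offset arithmetic: expand the index to match G’s layout and gather along the L dimension. A sketch, using the same shapes as above:

```python
import torch

N, L, V = 1, 32, 128
C, H, W = 3, 64, 64

G = torch.randn(N, L, C, H, W)
index = torch.randint(0, L, (V, H, W))

# broadcast index to N * V * C * H * W so it can address dim 1 (the L axis) of G
idx = index.view(1, V, 1, H, W).expand(N, V, C, H, W)
# G_kvp[n, v, c, h, w] = G[n, index[v, h, w], c, h, w]
G_kvp = G.gather(1, idx)
print(G_kvp.shape)  # torch.Size([1, 128, 3, 64, 64])
```

Since `expand` only creates a view, this avoids materializing the broadcast index, and `gather` is a single fused kernel on GPU.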