Hello,
I would like to ask for your help with writing a vectorized version of the following simple operation.
I have a 3D tensor x of shape (N, B, V), and I would like to get the elements of x at the indices given by two (N, K) tensors idx1 and idx2, as follows: y[i, j] = x[i, idx1[i, j], idx2[i, j]].
Using a for loop, this can be done with the following function:
import torch


def f(x, idx1, idx2):
    """Compute y using for loops.

    x: N x B x V tensor
    idx1: N x K matrix where idx1[i, j] is in [0, B)
    idx2: N x K matrix where idx2[i, j] is in [0, V)
    Return:
        y: N x K matrix where y[i, j] = x[i, idx1[i, j], idx2[i, j]]
    """
    N, K = idx1.shape
    # N x K output with the same dtype and device as x
    y = torch.zeros(N, K, dtype=x.dtype, device=x.device)
    for i in range(N):
        for j in range(K):
            y[i, j] = x[i, idx1[i, j], idx2[i, j]]
    return y
It seems that torch.gather is not applicable here. If we do z = x[:, idx1, idx2], then it remains to do y[i, j] = z[i, i, j], which doesn't seem to be any easier.
For your convenience, the above function can be tested using the following:
def main():
    N = 2
    B = 3
    V = 10
    K = B * 2
    x = torch.randn(N, B, V)
    idx1 = torch.randint(0, B, size=(N, K))
    idx2 = torch.randint(0, V, size=(N, K))
    y = f(x, idx1, idx2)
    print(y)  # N x K result


if __name__ == "__main__":
    main()
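One possible direction, offered as a sketch rather than a definitive answer: advanced indexing broadcasts its index tensors against each other, so a row index of shape (N, 1) can be paired with the (N, K) tensors idx1 and idx2 to pick one element per (i, j). The names f_vectorized and rows are illustrative and are not part of the code above.

```python
import torch


def f_vectorized(x, idx1, idx2):
    """Sketch of a loop-free variant: y[i, j] = x[i, idx1[i, j], idx2[i, j]].

    Assumes x has shape (N, B, V) and idx1, idx2 are integer tensors of
    shape (N, K) with values in [0, B) and [0, V) respectively.
    """
    N = idx1.shape[0]
    # rows has shape (N, 1); it broadcasts against the (N, K) index tensors,
    # so position (i, j) of the result reads x[i, idx1[i, j], idx2[i, j]].
    rows = torch.arange(N, device=x.device).unsqueeze(1)
    return x[rows, idx1, idx2]


if __name__ == "__main__":
    N, B, V = 2, 3, 10
    K = B * 2
    x = torch.randn(N, B, V)
    idx1 = torch.randint(0, B, size=(N, K))
    idx2 = torch.randint(0, V, size=(N, K))
    y = f_vectorized(x, idx1, idx2)
    print(y.shape)  # shape of the N x K result
```

The result can be compared against the loop version with torch.equal on the same inputs. Unlike z = x[:, idx1, idx2], this avoids building the intermediate (N, N, K) tensor.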
Thank you in advance for your help!