I would like to ask for your help with writing a vectorized version of the following simple operation.
I have a 3D tensor x of shape (N, B, V), and I would like to get the elements of x at indices given by two (N, K) tensors idx1 and idx2, as follows: y[i, j] = x[i, idx1[i,j], idx2[i,j]].
With a for loop, this can be done by the following function:

    import torch

    def f(x, idx1, idx2):
        """Compute y using a for loop.

        x: N x B x V tensor
        idx1: N x K matrix where idx1[i, j] is in [0, B)
        idx2: N x K matrix where idx2[i, j] is in [0, V)
        Return:
        y: N x K matrix where y[i, j] = x[i, idx1[i,j], idx2[i,j]]
        """
        # N x K output, matching the dtype and device of x
        y = torch.zeros(idx1.shape, dtype=x.dtype, device=x.device)
        N, K = idx1.shape
        for i in range(N):
            for j in range(K):
                y[i, j] = x[i, idx1[i, j], idx2[i, j]]
        return y
It seems that torch.gather is not applicable here. If we do z = x[:,idx1,idx2], then it remains to do y[i,j] = z[i,i,j], which doesn't seem to be any easier.
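One direction that might avoid the intermediate (N, N, K) tensor is to index the first dimension with an explicit row index of shape (N, 1), so that it broadcasts against idx1 and idx2 instead of being sliced with ":". The sketch below is only an illustration under that assumption; the names f_vectorized and f_gather are made up for this example, and the second variant assumes torch.gather can be used after flattening the last two dimensions of x and that the index tensors are int64.

```python
import torch


def f_vectorized(x, idx1, idx2):
    # x: (N, B, V); idx1, idx2: (N, K) integer tensors
    N = x.shape[0]
    # Row index of shape (N, 1); it broadcasts against the (N, K) indices
    rows = torch.arange(N, device=x.device).unsqueeze(1)
    # Advanced indexing: y[i, j] = x[rows[i, 0], idx1[i, j], idx2[i, j]]
    return x[rows, idx1, idx2]


def f_gather(x, idx1, idx2):
    # Alternative sketch: flatten (B, V) into one axis of size B * V,
    # then gather along that axis with the combined index.
    N, B, V = x.shape
    flat_idx = idx1 * V + idx2  # (N, K), values in [0, B * V)
    return x.reshape(N, B * V).gather(1, flat_idx)
```

Both functions should return an N x K tensor and can be compared against the loop version with torch.equal.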
For your convenience, the function f above can be tested using the following:

    def main():
        N = 2
        B = 3
        V = 10
        K = B * 2
        x = torch.randn(N, B, V)
        idx1 = torch.randint(0, B, size=(N, K))
        idx2 = torch.randint(0, V, size=(N, K))
        y = f(x, idx1, idx2)
        print(y.shape)  # expected: torch.Size([N, K])

    if __name__ == "__main__":
        main()
Thank you in advance for your help!