# Similar to torch.gather over two dimensions

Hello,

I would like to ask for your help on making a vectorized version of the following simple operation.

I have a 3D tensor `x` of shape `(N, B, V)` and I would like to get the elements of `x` at indices given by two `(N, K)` tensors `idx1` and `idx2` as follows: `y[i, j] = x[i, idx1[i,j], idx2[i,j]]`.

Using a for loop, this can be done using the following function:

def f(x, idx1, idx2):
    """Gather elements of a 3D tensor along its last two dimensions, one at a time.

    Args:
        x: tensor of shape (N, B, V).
        idx1: (N, K) integer tensor with values in [0, B).
        idx2: (N, K) integer tensor with values in [0, V).

    Returns:
        (N, K) tensor y with y[i, j] = x[i, idx1[i, j], idx2[i, j]].
    """
    # Accumulate results element by element (intentionally naive baseline).
    out = torch.zeros(idx1.shape)
    n_rows, n_cols = idx1.shape
    for row in range(n_rows):
        for col in range(n_cols):
            out[row, col] = x[row, idx1[row, col], idx2[row, col]]
    return out

It seems that `torch.gather` is not applicable here. If we do `z = x[:,idx1,idx2]`, then it remains to do `y[i,j] = z[i,i,j]`, which doesn’t seem to be any easier.

For your convenience the above function can be tested using the following:

def main():
    """Build a small random instance and run f on it (sanity driver)."""
    N, B, V = 2, 3, 10
    K = B * 2
    x = torch.randn(N, B, V)
    idx1 = torch.randint(0, B, size=(N, K))
    idx2 = torch.randint(0, V, size=(N, K))
    y = f(x, idx1, idx2)

I have finally found a solution but I’m not sure if it’s the most efficient one:

# x[:, idx1, idx2] has shape (N, N, K); 'iij->ij' keeps only its (i, i, j)
# diagonal, leaving y[i, j] = x[i, idx1[i, j], idx2[i, j]].
y = torch.einsum('iij->ij', x[:, idx1, idx2])

Hi,

Alternatively, you can linearize the two dimensions into one and then use `gather`, like this:

``````import torch

def f(x, idx1, idx2):
    """Reference implementation using an explicit double loop.

    x: (N, B, V) tensor.
    idx1: (N, K) integer tensor with values in [0, B).
    idx2: (N, K) integer tensor with values in [0, V).
    Returns an (N, K) tensor y with y[i, j] = x[i, idx1[i, j], idx2[i, j]].
    """
    y = torch.zeros(idx1.shape)
    for i, (row1, row2) in enumerate(zip(idx1, idx2)):
        for j, (a, b) in enumerate(zip(row1, row2)):
            y[i, j] = x[i, a, b]
    return y

def f2(x, idx1, idx2):
    """Vectorized version of f: flatten (B, V) into one axis and gather.

    Equivalent to y[i, j] = x[i, idx1[i, j], idx2[i, j]].
    """
    # Flatten the last two dims so a single gather can index both at once.
    flat = x.contiguous().view(x.size(0), -1)
    # Row-major linear offset into the flattened (B * V) axis.
    linear = idx2 + x.size(-1) * idx1
    return flat.gather(-1, linear)

# Quick check that the loop version and the gather version agree.
N, B, V = 2, 3, 10
K = B * 2
x = torch.randn(N, B, V)
idx1 = torch.randint(0, B, size=(N, K))
idx2 = torch.randint(0, V, size=(N, K))

y = f(x, idx1, idx2)
y2 = f2(x, idx1, idx2)

# Should print 0 (or a value at float precision) if both agree.
print((y - y2).abs().max())
2 Likes

Wow, that was fast! Thank you so much, @albanD! Let me try your solution to see how fast it is.

1 Like

@albanD No surprise, your version is super fast! Using the following benchmark code, I obtained on my Mac (CPU):

``````10 trials took:
f1: 2.2150492668151855
f2: 0.36966371536254883
f3: 0.004826068878173828
``````

Thank you so much again!

``````import time
import torch

def f1(x, idx1, idx2):
    """Baseline: element-by-element gather with Python loops.

    x: N x B x V
    idx1: N x K matrix where idx1[i, j] is between [0, B)
    idx2: N x K matrix where idx2[i, j] is between [0, V)
    Return:
    y: N x K matrix where y[i, j] = x[i, idx1[i,j], idx2[i,j]]
    """
    y = torch.zeros(idx1.shape)
    num_rows, num_cols = idx1.shape
    for i in range(num_rows):
        slab = x[i]  # hoist the first index out of the inner loop
        for j in range(num_cols):
            y[i, j] = slab[idx1[i, j], idx2[i, j]]
    return y

def f2(x, idx1, idx2):
    """Same as f1 but vectorized via advanced indexing + einsum.

    x[:, idx1, idx2] has shape (N, N, K); taking its (i, i, j) diagonal with
    einsum 'iij->ij' leaves y[i, j] = x[i, idx1[i, j], idx2[i, j]].
    Slower than f3 because the intermediate (N, N, K) tensor is materialized.

    NOTE(review): the original body of this function was truncated in the
    paste (only the docstring survived); reconstructed from the einsum
    solution quoted earlier in this thread, which the benchmark exercised.
    """
    # N x K
    return torch.einsum('iij->ij', x[:, idx1, idx2])

def f3(x, idx1, idx2):
    """Fastest variant: linearize the last two dims, then torch.gather.

    https://discuss.pytorch.org/t/similar-to-torch-gather-over-two-dimensions/118827/3?u=f10w
    """
    contiguous = x.contiguous()
    # idx1 picks the B-row, idx2 the V-column; combine into one row-major offset.
    offsets = idx2 + contiguous.size(-1) * idx1
    flattened = contiguous.view(contiguous.size(0), -1)
    return flattened.gather(-1, offsets)

def main():
    """Benchmark f1/f2/f3 on random inputs and verify they all agree."""
    N, B, V = 100, 100, 10000
    K = B * 2

    repeat = 10
    totals = [0.0, 0.0, 0.0]

    for _ in range(repeat):
        x = torch.randn(N, B, V)
        idx1 = torch.randint(0, B, size=(N, K))
        idx2 = torch.randint(0, V, size=(N, K))

        outputs = []
        for slot, fn in enumerate((f1, f2, f3)):
            start = time.time()
            outputs.append(fn(x, idx1, idx2))
            totals[slot] += time.time() - start

        y1, y2, y3 = outputs
        # All three implementations must produce identical results.
        assert (y1 - y2).abs().max() < 1e-10
        assert (y1 - y3).abs().max() < 1e-10

    print(f'{repeat} trials took:\nf1: {totals[0]}\nf2: {totals[1]}\nf3: {totals[2]}')

if __name__ == "__main__":
    main()
2 Likes