Tensor index operation

I have two tensors, a and b, with a.shape = [2, n] and b.shape = [n]. Each element b[i] is 0 or 1 and selects either a[0][i] or a[1][i].
I want to get a tensor c with c.shape = [n], where c[i] = a[0][i] if b[i] == 0 and c[i] = a[1][i] if b[i] == 1.
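To pin down the intended semantics, here is a tiny worked case using the element-wise loop the question wants to avoid. The values of a and b are made up purely for illustration.

```python
import torch

# Hypothetical small inputs with n = 4
a = torch.tensor([[10., 11., 12., 13.],
                  [20., 21., 22., 23.]])
b = torch.tensor([0, 1, 1, 0])  # 0 -> take row 0, 1 -> take row 1

# Reference loop: c[i] comes from row b[i] of a, at column i
c = torch.empty(a.shape[1])
for i in range(a.shape[1]):
    c[i] = a[int(b[i]), i]

print(c)  # tensor([10., 21., 22., 13.])
```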

So how can I get c without a for loop?

I think you can use torch.where here. This compares it against the equivalent Python loop:

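import time
import torch

n = 2048
a = torch.randn(2, n)
b = torch.randn(n) > 0  # bool mask; for a 0/1 integer tensor, use b.bool()

# Baseline: element-wise Python loop
t1 = time.time()
out1 = torch.empty(n)
for i in range(n):
    if b[i]:
        out1[i] = a[1][i]
    else:
        out1[i] = a[0][i]
t2 = time.time()

# Vectorized: take a[1][i] where b[i] is True, otherwise a[0][i]
out2 = torch.where(b, a[1], a[0])
t3 = time.time()

print(torch.allclose(out1, out2))  # both approaches agree
print(t2 - t1, t3 - t2)  # loop time vs. torch.where time, in seconds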
Output:
True
0.0062215328216552734 3.409385681152344e-05
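The question describes b as holding 0/1 values rather than booleans, and torch.where expects a bool condition, so an integer b needs converting first. Below is a minimal sketch, assuming b is an int64 tensor of 0s and 1s, that shows that conversion alongside two other vectorized options: advanced indexing (pairing row index b[i] with column index i) and torch.gather along dim 0. These two alternatives are swapped-in techniques, not part of the answer above; unlike torch.where, they also extend to selecting among more than two rows.

```python
import torch

n = 2048
a = torch.randn(2, n)
# Integer 0/1 selector, as described in the question (not a bool mask)
b = torch.randint(0, 2, (n,))

# Option 1: torch.where needs a bool condition, so convert the 0/1 tensor
out_where = torch.where(b.bool(), a[1], a[0])

# Option 2: advanced indexing; element i is a[b[i], i]
out_index = a[b, torch.arange(n)]

# Option 3: gather along dim 0; the index must have as many dims as a
out_gather = a.gather(0, b.unsqueeze(0)).squeeze(0)

print(torch.equal(out_where, out_index), torch.equal(out_where, out_gather))  # True True
```

All three pick exact elements rather than computing anything, so torch.equal is a stricter and still valid check here than torch.allclose.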