I have two tensors, a and b, with a.shape = [2, n] and b.shape = [n]. Each element b[i] is 0 or 1 and selects either a[0][i] or a[1][i]. I want a tensor c with c.shape = [n], where c[i] = a[0][i] if b[i] == 0 and c[i] = a[1][i] if b[i] == 1.
How can I compute c without a for loop?
I think you can use torch.where here:
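```python
import time
import torch

n = 2048
a = torch.randn(2, n)
b = torch.randn(n) > 0  # random boolean mask of shape [n]

# Baseline: explicit Python loop
t1 = time.time()
out1 = torch.empty(n)
for i in range(n):
    if b[i]:
        out1[i] = a[1][i]
    else:
        out1[i] = a[0][i]
t2 = time.time()

# Vectorized: take a[1] where b is True, a[0] elsewhere
out2 = torch.where(b, a[1], a[0])
t3 = time.time()

print(torch.allclose(out1, out2))  # both approaches agree
print(t2 - t1, t3 - t2)            # loop time, torch.where time
```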
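Output:

```
True
0.0062215328216552734 3.409385681152344e-05
```

The results match, and in this run torch.where is roughly 180 times faster than the loop (exact timings depend on the machine).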
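The benchmark builds b as a boolean mask, but the question describes b as holding 0/1 integers. torch.where documents its condition as a boolean tensor, so a safe choice is to convert with b.bool() first. The sketch below assumes b is an integer tensor of 0s and 1s (n = 8 is arbitrary). It also shows two alternatives that are not part of the answer above: torch.gather along dim 0, and advanced indexing that pairs row b[i] with column i. Both use b directly as a row index.

```python
import torch

n = 8
a = torch.randn(2, n)
# Assumption: b holds 0/1 integers (int64), as described in the question
b = torch.randint(0, 2, (n,))

# torch.where expects a boolean condition, so convert the 0/1 tensor first:
# positions where b == 1 take a[1], the rest take a[0]
c_where = torch.where(b.bool(), a[1], a[0])

# Alternative: gather along dim 0 with b as the row index.
# gather needs an int64 index with the same number of dims as a, hence unsqueeze.
c_gather = a.gather(0, b.long().unsqueeze(0)).squeeze(0)

# Alternative: advanced indexing, selecting a[b[i], i] for every i
c_index = a[b, torch.arange(n)]

print(torch.equal(c_where, c_gather), torch.equal(c_where, c_index))
```

All three produce the same tensor. torch.where reads most directly for a two-way choice, while the indexing forms generalize to selecting among more than two rows.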