How to index a 2d tensor by another 2d tensor?

Given a 2d tensor a with a.shape = (B, D), I want to randomly sample k elements from each row of this tensor, so I wrote this code:

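```python
import torch

B, D, k = 4, 10, 3  # example sizes
a = torch.randn(B, D)
p = torch.ones((B, D))  # uniform sampling weights for every column
index = p.multinomial(num_samples=k, replacement=True)  # shape (B, k)
```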

The shape of index is (B, k) and the shape of a is (B, D). How can I use index to pick elements from a, i.e. out[i, j] = a[i, index[i, j]]?
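Spelled out with explicit Python loops, the requested semantics look like the sketch below. It is only a slow reference for checking a vectorized version; the helper name `gather_rows_loop` and the small fixed tensors are hypothetical, chosen for illustration.

```python
import torch

def gather_rows_loop(a: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: out[i, j] = a[i, index[i, j]]."""
    B, k = index.shape
    out = torch.empty(B, k, dtype=a.dtype, device=a.device)
    for i in range(B):
        for j in range(k):
            out[i, j] = a[i, index[i, j]]  # pick column index[i, j] from row i
    return out

a = torch.arange(10.0).reshape(2, 5)          # rows [0..4] and [5..9]
index = torch.tensor([[4, 0, 0], [1, 3, 2]])  # column picks per row
out = gather_rows_loop(a, index)
print(out)
# tensor([[4., 0., 0.],
#         [6., 8., 7.]])
```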

Hi Yu!

torch.gather(a, 1, index) should do what you want.
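A minimal sketch of that call, assuming small sizes B=3, D=5, k=4 chosen only for illustration. With dim=1, gather returns a tensor with the same shape as index, (B, k), where out[i, j] = a[i, index[i, j]]; index must be an int64 (long) tensor, which is what multinomial returns. As a cross-check, the same result is also computed with advanced indexing, pairing each row number with that row's column indices.

```python
import torch

torch.manual_seed(0)
B, D, k = 3, 5, 4  # illustrative sizes

a = torch.randn(B, D)
p = torch.ones((B, D))                                  # uniform weights over the D columns
index = p.multinomial(num_samples=k, replacement=True)  # shape (B, k), dtype int64

# Pick a[i, index[i, j]] for every (i, j); the result has the shape of index.
out = torch.gather(a, 1, index)

# Equivalent advanced indexing: row numbers of shape (B, 1) broadcast against index (B, k).
rows = torch.arange(B).unsqueeze(1)
out_adv = a[rows, index]

print(out.shape)                  # torch.Size([3, 4])
print(torch.equal(out, out_adv))  # True
```

Because the weights in p are uniform, torch.randint(0, D, (B, k)) would produce an index tensor with the same distribution; multinomial only matters when the weights differ.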

Best.

K. Frank
