Given a 2D tensor `a` with shape (B, D), I want to randomly sample k elements from each row, so I wrote this code:

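```
import torch

B, D, k = 4, 10, 3  # example sizes, chosen only so the snippet runs
a = torch.randn(B, D)
p = torch.ones((B, D))  # uniform weights over the D columns of each row
index = p.multinomial(num_samples=k, replacement=True)  # k column indices per row
```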

`index` has shape (B, k) and `a` has shape (B, D). How can I use `index` to select elements from `a`?

That is, I want `out[i, j] = a[i, index[i, j]]`.
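
For reference, the rule above looks like what `torch.gather` computes along dimension 1, so here is a minimal sketch of what I am after. The sizes `B, D, k = 4, 10, 3` are made up for illustration, and `rows` and `out_alt` are names I introduced for an advanced-indexing variant that I believe gives the same result.

```python
import torch

B, D, k = 4, 10, 3  # hypothetical example sizes

a = torch.randn(B, D)
p = torch.ones((B, D))
index = p.multinomial(num_samples=k, replacement=True)  # shape (B, k), integer indices

# Gather along dim 1: out[i, j] = a[i, index[i, j]]
out = torch.gather(a, 1, index)

# Same selection with advanced indexing: a (B, 1) row index
# broadcasts against the (B, k) column index.
rows = torch.arange(B).unsqueeze(1)
out_alt = a[rows, index]
```

Both `out` and `out_alt` should have shape (B, k). `torch.gather` states the intent most directly, while the indexing form may be handy if the row indices are not simply 0..B-1.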