Suppose I have the following embeddings: `emb_user = torch.randn(64, 128, 256)`. From the second dimension (of length 128), I wish to pick out 16 distinct positions at random, independently for each instance in the batch. I was wondering if there is a more efficient way of doing the following:

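```python
import torch

emb_user = torch.randn(64, 128, 256)
idx = torch.multinomial(torch.ones(64, 128), 16)  # (64, 16): 16 distinct indices per row
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]  # (64, 16, 256)
```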
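For comparison, here is one possible alternative that is not from the original post: draw a random key for every (row, position) pair, take the first 16 positions of each row's `argsort`, and gather along dimension 1. Because the keys are i.i.d. uniform, each row gets a uniformly random subset of 16 distinct positions, matching what `torch.multinomial` with equal weights and `replacement=False` produces. The helper name `sample_rows` is hypothetical, and this is only a sketch; whether it is actually faster than the `multinomial` version depends on sizes and device, so it is worth benchmarking.

```python
import torch


def sample_rows(emb: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: pick k distinct positions along dim 1, independently per row of dim 0."""
    batch, n, dim = emb.shape
    # One uniform random key per (row, position); sorting the keys yields a random
    # permutation of each row, and its first k entries are a random k-subset.
    idx = torch.rand(batch, n, device=emb.device).argsort(dim=1)[:, :k]  # (batch, k)
    # Broadcast the indices over the embedding dimension and gather along dim 1.
    gathered = torch.gather(emb, 1, idx.unsqueeze(-1).expand(-1, -1, dim))  # (batch, k, dim)
    return gathered, idx


emb_user = torch.randn(64, 128, 256)
sampled_emb_user, idx = sample_rows(emb_user, 16)
```

The `gather` call gives the same result as the advanced-indexing expression in the question, `emb_user[torch.arange(64).unsqueeze(-1), idx]`; `gather` just makes the per-row selection explicit.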

What I also find curious is that the above `multinomial` call does not work if the weight matrix (`torch.ones(64, 128)`) has more than 2 dimensions.
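The PyTorch documentation for `torch.multinomial` only describes a vector input (giving `num_samples` samples) or a matrix input with `m` rows (giving an `m × num_samples` result), so a 3-D weight tensor falls outside what it supports. A common workaround, sketched below under that assumption, is to fold all leading dimensions into one batch dimension, sample, and reshape back. The helper name `multinomial_nd` is hypothetical.

```python
import torch


def multinomial_nd(weights: torch.Tensor, num_samples: int, replacement: bool = False) -> torch.Tensor:
    """Hypothetical helper: sample indices over the last dim of a tensor with any leading dims."""
    lead_shape = weights.shape[:-1]
    # torch.multinomial takes 1-D or 2-D input, so flatten the leading dims into one.
    flat = weights.reshape(-1, weights.shape[-1])
    samples = torch.multinomial(flat, num_samples, replacement=replacement)
    # Restore the leading dims; the last dim now holds the sampled indices.
    return samples.reshape(*lead_shape, num_samples)


samples = multinomial_nd(torch.ones(4, 64, 128), 16)  # shape (4, 64, 16)
```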