Selecting n indices without replacement from dimension x

Suppose I have the following embeddings: emb_user = torch.randn(64, 128, 256). From the second dimension (of length 128), I wish to pick 16 entries at random, without replacement, independently for each of the 64 items along the first dimension. Is there a more efficient way of doing the following?

idx = torch.multinomial(torch.ones(64, 128), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
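One alternative I have considered is to draw a random score for every position and keep the indices of the 16 largest scores per row, which should also give a uniform sample without replacement. This is only a sketch: n_samples is a name I made up for illustration, and I have not benchmarked whether it is actually faster than the multinomial version.

```python
import torch

emb_user = torch.randn(64, 128, 256)
n_samples = 16  # hypothetical name for the number of indices to draw per row

# One uniform random score per (row, position); shape (64, 128).
scores = torch.rand(emb_user.shape[0], emb_user.shape[1])

# Indices of the n_samples largest scores in each row; shape (64, 16).
# Each row contains distinct indices, so this samples without replacement.
idx = scores.topk(n_samples, dim=1).indices

# Gather the selected slices along dim 1; the index is expanded to (64, 16, 256).
gather_idx = idx.unsqueeze(-1).expand(-1, -1, emb_user.shape[-1])
sampled_emb_user = emb_user.gather(1, gather_idx)
```

The result has shape (64, 16, 256) and matches what the advanced-indexing expression above would return for the same idx.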

What I also find curious is that the above torch.multinomial call would not work if the weight tensor (torch.ones(64, 128)) had more than 2 dimensions.
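As far as I can tell, torch.multinomial only accepts a 1-D or 2-D input, so the workaround I would try is to flatten the leading dimensions, sample, and reshape the result back. A minimal sketch, assuming a made-up weight tensor of shape (4, 8, 128):

```python
import torch

weights = torch.ones(4, 8, 128)  # hypothetical 3-D weights, for illustration only
n_samples = 16

# Collapse all leading dimensions into one so the input is 2-D: (32, 128).
flat_weights = weights.reshape(-1, weights.shape[-1])

# Sample without replacement (the default) from each row: (32, 16).
flat_idx = torch.multinomial(flat_weights, n_samples)

# Restore the leading dimensions: (4, 8, 16).
idx = flat_idx.reshape(*weights.shape[:-1], n_samples)
```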