Suppose I have the following embeddings: emb_user = torch.randn(64, 128, 256). Along the second dimension (of length 128), I want to pick 16 entries at random, independently for each of the 64 instances. Is there a more efficient way of doing the following?
idx = torch.multinomial(torch.ones(64, 128), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
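For comparison, one possible alternative is to draw a uniform random key per (instance, item) pair and keep the indices of the 16 largest keys in each row, which also yields 16 distinct indices per instance. This is only a sketch: the names batch, n_items, dim and k are introduced here for illustration, and whether it is actually faster than torch.multinomial is an assumption that would need to be benchmarked on the target device.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible

emb_user = torch.randn(64, 128, 256)
batch, n_items, dim = emb_user.shape  # illustrative names, not from the original code
k = 16

# One uniform random key per (instance, item); the indices of the k largest
# keys in each row are k distinct positions chosen uniformly at random.
idx = torch.rand(batch, n_items).topk(k, dim=1).indices  # shape (64, 16)

# Broadcast the indices over the feature dimension and gather along dim 1.
sampled_emb_user = emb_user.gather(1, idx.unsqueeze(-1).expand(-1, -1, dim))  # shape (64, 16, 256)
```

The gather call selects the same rows as the advanced-indexing expression emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]; either form can be used with either way of producing idx.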
What I also find curious is that the above torch.multinomial call would not work if the weight matrix (here torch.ones(64, 128)) had more than 2 dimensions.
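torch.multinomial only accepts a 1-D or 2-D tensor of weights. A possible workaround, sketched below under the assumption that sampling should happen along the last dimension, is to flatten all leading dimensions into one, sample, and then restore the leading shape. The 3-D tensor weights and the variable names are hypothetical and used only for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible

weights = torch.ones(4, 64, 128)  # hypothetical 3-D weight tensor
k = 16

# Collapse every leading dimension so multinomial sees a 2-D input.
flat = weights.reshape(-1, weights.shape[-1])  # shape (256, 128)

# Sample k indices per row without replacement (the default).
flat_idx = torch.multinomial(flat, k)  # shape (256, 16)

# Restore the original leading dimensions.
idx = flat_idx.reshape(*weights.shape[:-1], k)  # shape (4, 64, 16)
```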