Suppose I have the following embeddings:
emb_user = torch.randn(64, 128, 256)
From the second dimension (of length 128), I wish to pick 16 entries at random for each of the 64 instances. Is there a more efficient way of doing the following?
idx = torch.multinomial(torch.ones(64, 128), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
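For comparison, here is one possible alternative, offered as a sketch rather than a benchmarked improvement: draw a uniform random score for every candidate and keep the indices of the 16 largest scores in each row. That gives a uniform sample without replacement per row, which is what multinomial with all-ones weights produces. The helper name sample_rows is hypothetical, and whether this is actually faster depends on the device and tensor sizes, so it would need timing.

```python
import torch

torch.manual_seed(0)
emb_user = torch.randn(64, 128, 256)

def sample_rows(emb, k):
    # Hypothetical helper: emb has shape (batch, n, dim).
    batch, n, dim = emb.shape
    # One uniform random score per candidate; the top-k positions of each
    # row form a uniform sample of k distinct indices.
    scores = torch.rand(batch, n, device=emb.device)
    idx = scores.topk(k, dim=1).indices            # (batch, k)
    # gather needs an index with the same number of dims as emb.
    gather_idx = idx.unsqueeze(-1).expand(-1, -1, dim)
    return emb.gather(1, gather_idx), idx          # (batch, k, dim), (batch, k)

sampled_emb_user, idx = sample_rows(emb_user, 16)
```

The gather call returns the same values as the advanced-indexing expression emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx], so either form can be used once idx is available.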
What I also find curious is that the above multinomial call would not work if the weight matrix (torch.ones(64, 128)) had more than 2 dimensions.
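As far as I can tell, torch.multinomial is documented to take either a vector or a matrix of weights, which would explain the restriction. A possible workaround, sketched under the assumption that the last dimension holds the categories to sample from: flatten all leading dimensions into one, sample, and reshape the indices back. The helper name multinomial_nd is hypothetical and not part of PyTorch.

```python
import torch

def multinomial_nd(weights, num_samples, replacement=False):
    # Hypothetical helper: treat every leading dimension as a batch
    # dimension and sample along the last one.
    *lead, n = weights.shape
    flat = weights.reshape(-1, n)                  # (prod(lead), n)
    idx = torch.multinomial(flat, num_samples, replacement=replacement)
    return idx.reshape(*lead, num_samples)         # (*lead, num_samples)

weights = torch.ones(4, 64, 128)                   # 3-D weight tensor
idx = multinomial_nd(weights, 16)                  # shape (4, 64, 16)
```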