I am implementing an algorithm that has to send many small queries to an nn.Embedding layer during training. Since my embedding is on CUDA, I have to move my indices to CUDA for every query, which is extremely expensive. So I have the following questions:
- Why do the inputs have to be on the same device as the weight when nn.Embedding is basically just a lookup table?
- I am thinking of a workaround, which is to reimplement the embedding layer with a raw tensor and perform manual indexing, for example:

```python
import torch

embs = torch.rand((5, 5), device="cuda", requires_grad=True)
optimizer = torch.optim.Adam([embs], lr=0.005)
loss = ((embs[[0, 2, 4]] - torch.ones_like(embs[[0, 2, 4]])) ** 2).sum()
loss.backward()
optimizer.step()
```
Is this approach safe and efficient? Is manual indexing effectively the same as moving the indices to CUDA and then gathering the values?
- Any other suggestions to solve this problem?
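For reference, here is a fuller, self-contained version of the workaround from my second question, so it can be run as-is (it falls back to CPU when CUDA is unavailable just to keep the snippet runnable; the check on which rows receive gradient is my own addition):

```python
import torch

# Fall back to CPU so the snippet runs anywhere; on my setup this is "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Raw parameter tensor standing in for nn.Embedding.weight.
embs = torch.rand((5, 5), device=device, requires_grad=True)
optimizer = torch.optim.Adam([embs], lr=0.005)

# Indices stay a plain Python list -- no explicit .to(device) call in my code.
idx = [0, 2, 4]

optimizer.zero_grad()
loss = ((embs[idx] - torch.ones_like(embs[idx])) ** 2).sum()
loss.backward()
optimizer.step()

# Only the selected rows should receive gradient; rows 1 and 3 stay zero.
print(embs.grad.abs().sum(dim=1))
```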
Thanks a lot.