I am implementing an algorithm that sends many small queries to an nn.Embedding during training. Since the embedding lives on CUDA, I have to move my indices to CUDA for every query, which is extremely expensive. So I have the following questions:
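To make the pattern concrete, here is a minimal sketch of what I mean (the sizes, names, and the CPU fallback are just for illustration; in my real code the embedding is always on CUDA):

```python
import torch
import torch.nn as nn

# Use CUDA when available; fall back to CPU so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

emb = nn.Embedding(1000, 64).to(device)

idx = torch.tensor([3, 17, 42])   # indices are produced on the CPU
out = emb(idx.to(device))         # host-to-device copy on every single query
```

The `idx.to(device)` call is the part that happens on every lookup and that I would like to avoid.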
1. Why do the inputs have to be on the same device as the weight when nn.Embedding is basically just a lookup table?
2. I am considering a workaround: reimplement the embedding layer with a raw tensor and index it manually, for example:
```python
import torch

embs = torch.rand((5, 5), device="cuda", requires_grad=True)
optimizer = torch.optim.Adam([embs], lr=0.005)

idx = [0, 2, 4]  # plain Python list used directly for indexing
loss = ((embs[idx] - torch.ones_like(embs[idx])) ** 2).sum()
loss.backward()
optimizer.step()
```
Is this approach safe and efficient? Is manual indexing equivalent to moving the indices to CUDA and then retrieving the values?
3. Are there any other suggestions for solving this problem?
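To clarify what I mean in question 2, this is the equivalence I am asking about (shown on CPU only so the snippet runs anywhere; the names are mine):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(5, 5)        # CPU here just for illustration
weight = emb.weight             # the raw parameter tensor behind the module

idx = torch.tensor([0, 2, 4])
out_module = emb(idx)           # lookup through the nn.Embedding module
out_manual = weight[idx]        # manual advanced indexing on the raw tensor

# Both paths return the same rows of the weight matrix.
assert torch.equal(out_module, out_manual)
```

My question is whether these two paths also behave the same in terms of device placement and performance, not just in the values they return.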
Thanks a lot.