Why does nn.Embedding expect inputs to be on the same device as the weights?

I am implementing an algorithm that has to send many small queries to an nn.Embedding layer while training. Since the embedding is on CUDA, I have to move my indices to CUDA for every query, which is extremely expensive. So I have the following questions:

  1. Why do the inputs have to be on the same device as the weight when nn.Embedding is basically just a lookup table?
  2. I am thinking of a workaround: reimplement the embedding layer as a raw tensor and perform manual indexing, for example:
import torch

# Raw embedding matrix on the GPU, updated directly by the optimizer
embs = torch.rand((5, 5), device="cuda", requires_grad=True)
optimizer = torch.optim.Adam([embs], lr=0.005)

# Look up a few rows by manual indexing and push them towards ones
loss = ((embs[[0, 2, 4]] - torch.ones_like(embs[[0, 2, 4]])) ** 2).sum()
loss.backward()
optimizer.step()

Is this approach safe and efficient? Is manual indexing effectively the same as moving the indices to CUDA and then retrieving the values?
3. Any other suggestions to solve this problem?

Thanks a lot.

  1. Since the operation is executed on the GPU, the inputs, just like any other tensors involved, must be on the device before the kernel can be launched (see the sketch below).

  2. With manual indexing the indices will also end up on the device, so there shouldn’t be a difference.
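
To illustrate both points, here is a minimal sketch (assuming a CUDA device is available; the 1000×64 embedding size is arbitrary, chosen only for illustration): passing CPU indices to a CUDA nn.Embedding raises a device-mismatch error, and manual indexing of the weight returns the same rows once the indices are on the GPU, so it does not avoid the transfer.

import torch
import torch.nn as nn

device = torch.device("cuda")  # assumes a CUDA device is available

emb = nn.Embedding(1000, 64).to(device)   # hypothetical sizes, for illustration
cpu_idx = torch.tensor([0, 2, 4])         # indices living on the CPU

# 1. The lookup kernel runs on the GPU, so CPU indices are rejected.
try:
    emb(cpu_idx)
except RuntimeError as err:
    print("device mismatch:", err)

# Moving the indices first makes the call valid.
cuda_idx = cpu_idx.to(device)
out = emb(cuda_idx)

# 2. Manual indexing reads the same rows, but the indices still have to
#    reach the GPU, so it is not cheaper than calling the module itself.
manual = emb.weight[cuda_idx]
print(torch.allclose(out, manual))  # True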