`DataParallel` won’t work in this case, as it copies the whole model to each GPU and splits the data along dim0, sending each chunk to the corresponding GPU.
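To make the problem concrete, here is a small CPU-only sketch of that splitting behavior. `torch.chunk` is used only as a stand-in for the scatter step and `num_gpus` is an assumed value; this is not meant to reproduce what `DataParallel` does internally. The point is that each replica gets a smaller input but still needs the full weight matrix:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 300)        # every replica would hold all 10 x 300 weights
x = torch.randint(0, 10, (8, 7))   # batch of 8 index sequences

num_gpus = 2  # assumed device count, for illustration only
chunks = torch.chunk(x, num_gpus, dim=0)  # stand-in for the split along dim0

for i, chunk in enumerate(chunks):
    # The input shrinks per replica, the weight matrix does not
    print(f"replica {i}: input {tuple(chunk.shape)}, weight {tuple(emb.weight.shape)}")
```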

I haven’t tested this approach, but it might be possible to split the weight matrix and push the splits onto different GPUs (similar to model sharding, but within a single layer).

Here is some dummy code (currently CPU only) to check that both approaches give equal results:

```
import torch
import torch.nn as nn

# Reference embedding, which is too large for a single GPU
emb = nn.Embedding(10, 300)
# Two half-size embeddings that together hold the same 10 rows
emb1 = nn.Embedding(5, 300)
emb2 = nn.Embedding(5, 300)

# Copy weights so that we can compare both approaches
with torch.no_grad():
    emb1.weight.copy_(emb.weight[:5])
    emb2.weight.copy_(emb.weight[5:])

# Initialize some input randomly
x = torch.randint(0, 10, (3, 7))

# Split input and call corresponding embedding layer (on different GPUs)
emb1_idx = (x < 5).nonzero()
emb2_idx = (x >= 5).nonzero()
out1 = emb1(x[emb1_idx.split(1, 1)])
out2 = emb2(x[emb2_idx.split(1, 1)] - 5)  # shift indices into emb2's local range

# Concatenate output to check for equal results
out_cat = torch.zeros(x.size(0), x.size(1), 300)
out_cat[emb1_idx.split(1, 1)] = out1
out_cat[emb2_idx.split(1, 1)] = out2

# Get reference output
out = emb(x)

# Compare outputs
print((out == out_cat).all())

# Call backward and compare gradients
out_cat.mean().backward()
out.mean().backward()
print((torch.cat((emb1.weight.grad, emb2.weight.grad), 0) == emb.weight.grad).all())
```

In your code, you would have to push the embedding “sublayers” as well as the corresponding input chunks to different devices.
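As a rough, untested-on-multi-GPU sketch of what that could look like, the splitting logic can be wrapped in a module. `ShardedEmbedding` and its `devices` argument are hypothetical names made up for this example, not a PyTorch API, and the sketch assumes `num_embeddings` is at least as large as the number of devices:

```python
import torch
import torch.nn as nn


class ShardedEmbedding(nn.Module):
    """Hypothetical wrapper: splits one embedding table row-wise across devices."""

    def __init__(self, num_embeddings, embedding_dim, devices):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.devices = [torch.device(d) for d in devices]
        # Rows per shard, rounded up (the last shard may be smaller)
        self.shard_size = -(-num_embeddings // len(self.devices))
        self.shards = nn.ModuleList()
        for i, dev in enumerate(self.devices):
            rows = min(self.shard_size, num_embeddings - i * self.shard_size)
            self.shards.append(nn.Embedding(rows, embedding_dim).to(dev))

    def forward(self, x):
        # Collect the results on the device of the input indices
        out = torch.zeros(*x.shape, self.embedding_dim, device=x.device)
        for i, (shard, dev) in enumerate(zip(self.shards, self.devices)):
            lo = i * self.shard_size
            # Positions in x whose index falls into this shard's row range
            mask = (x >= lo) & (x < lo + shard.num_embeddings)
            if mask.any():
                local_idx = (x[mask] - lo).to(dev)  # shift to shard-local indices
                out[mask] = shard(local_idx).to(x.device)
        return out


# Fall back to CPU "devices" if fewer than two GPUs are available
devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
emb = ShardedEmbedding(10, 300, devices)
x = torch.randint(0, 10, (3, 7))
out = emb(x)
print(out.shape)
```

The output is gathered on the input’s device here for simplicity; if the full output is also too large for one device, you would have to keep the per-shard results separate instead.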

Let me know if I’m missing something obvious.