DataParallel won’t work in this case, as it will copy the model to each GPU and split the data along dim0, sending each chunk to the corresponding GPU.
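For illustration, here is a minimal sketch (untested, assumes at least two visible GPUs) of what nn.DataParallel would do here: it keeps a full copy of the embedding weight on every device and only scatters the batch dimension:
import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    emb = nn.Embedding(10, 300)
    # DataParallel replicates the whole module onto each listed GPU during forward
    dp_emb = nn.DataParallel(emb, device_ids=[0, 1]).to('cuda:0')
    # The batch (dim0) is scattered: 4 samples go to GPU0, 4 to GPU1,
    # but every GPU still holds a full copy of emb.weight
    x = torch.randint(0, 10, (8, 7), device='cuda:0')
    out = dp_emb(x)
    print(out.shape)  # torch.Size([8, 7, 300])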
I haven’t tested this approach, but it might be possible to split the weight matrix and push the splits onto different GPUs (similar to model sharding, but within a single layer).
Here is some dummy code (currently CPU-only) to test for equal results:
import torch
import torch.nn as nn

# Reference embedding, which is too large for a single GPU
emb = nn.Embedding(10, 300)
emb1 = nn.Embedding(5, 300)
emb2 = nn.Embedding(5, 300)
# Copy weights so that we can compare both approaches
with torch.no_grad():
    emb1.weight.copy_(emb.weight[:5])
    emb2.weight.copy_(emb.weight[5:])
# Initialize some input randomly
x = torch.randint(0, 10, (3, 7))
# Split input and call corresponding embedding layer (on different GPUs)
emb1_idx = (x < 5).nonzero()
emb2_idx = (x >= 5).nonzero()
out1 = emb1(x[emb1_idx.split(1, 1)])
out2 = emb2(x[emb2_idx.split(1, 1)] - 5)
# Concatenate output to check for equal results
out_cat = torch.zeros(x.size(0), x.size(1), 300)
out_cat[emb1_idx.split(1, 1)] = out1
out_cat[emb2_idx.split(1, 1)] = out2
# Get reference output
out = emb(x)
# Compare outputs
print((out == out_cat).all())
# Call backward and compare gradients
out.mean().backward()
out_cat.mean().backward()
print((torch.cat((emb1.weight.grad, emb2.weight.grad), 0) == emb.weight.grad).all())
In your code, you would have to push the embedding “sublayers” as well as the corresponding input chunks to different devices.
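Something like this rough sketch (untested; assumes two visible GPUs, and gathering the outputs on cuda:0 is just one possible choice):
import torch
import torch.nn as nn

# Each embedding shard lives on its own device
emb1 = nn.Embedding(5, 300).to('cuda:0')  # holds rows 0-4
emb2 = nn.Embedding(5, 300).to('cuda:1')  # holds rows 5-9

x = torch.randint(0, 10, (3, 7))
emb1_idx = (x < 5).nonzero()
emb2_idx = (x >= 5).nonzero()

# Push each index chunk to the device holding its shard and embed it there
out1 = emb1(x[emb1_idx.split(1, 1)].to('cuda:0'))
out2 = emb2((x[emb2_idx.split(1, 1)] - 5).to('cuda:1'))

# Gather the partial results on one device and scatter them into the output
out_cat = torch.zeros(x.size(0), x.size(1), 300, device='cuda:0')
out_cat[emb1_idx.split(1, 1)] = out1
out_cat[emb2_idx.split(1, 1)] = out2.to('cuda:0')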
Let me know if I’m missing something obvious.