How to put a module on different GPUs

I’m currently training a word2vec model with a large vocabulary (50 million words). When I use an embedding layer with 64 dimensions, the embedding weight (50 million × 64) exceeds the memory of a single GPU, so I want to put the weight on different GPUs.
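
For scale, a quick back-of-the-envelope check of the weight size alone (assuming fp32 parameters and ignoring gradients/optimizer state):

num_classes = 50_000_000
embed_size = 64
print(num_classes * embed_size * 4 / 1024**3)  # ~11.9 GiB just for the fp32 weight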

I’ve already tried DataParallel, but it still doesn’t work:

        self.in_embed = nn.Embedding(self.num_classes, self.embed_size, sparse=True)
        self.in_embed.weight = Parameter(t.FloatTensor(self.num_classes, self.embed_size).uniform_(-1, 1))
        self.in_embed = DataParallel(self.in_embed, device_ids=[0, 1, 2])
        self.in_embed.cuda()

DataParallel won’t work in this case, as it will copy the model to each GPU and split the data in dim0, sending each chunk to the corresponding GPU.
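
For illustration, here is a small untested sketch (tiny vocab, assuming at least two visible GPUs) of that behavior: the full weight is replicated on every device, not sharded, so the per-GPU memory requirement does not shrink.

import torch
import torch.nn as nn

# Untested sketch: DataParallel copies the whole embedding to each GPU
emb = nn.Embedding(1000, 64)
emb = nn.DataParallel(emb, device_ids=[0, 1]).cuda()

x = torch.randint(0, 1000, (6, 7), device='cuda:0')
out = emb(x)          # x is chunked along dim0; each GPU looks up in its own full copy
print(out.shape)      # torch.Size([6, 7, 64])
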
I haven’t tested this approach, but it might be possible to split the weight matrix and push the splits onto different GPUs (similar to model sharding, but for a single layer).
Here is some dummy code (currently CPU-only) to check that both approaches give equal results:

import torch
import torch.nn as nn

# Reference embedding, which would be too large for a single GPU
emb = nn.Embedding(10, 300)
emb1 = nn.Embedding(5, 300)
emb2 = nn.Embedding(5, 300)

# Copy weights so that we can compare both approaches
with torch.no_grad():
    emb1.weight.copy_(emb.weight[:5])
    emb2.weight.copy_(emb.weight[5:])

# Initialize some input randomly
x = torch.randint(0, 10, (3, 7))

# Split the input and call the corresponding embedding layer (these would sit on different GPUs)
emb1_idx = (x < 5).nonzero()
emb2_idx = (x >= 5).nonzero()

# emb*_idx.split(1, 1) yields (row_indices, col_indices) for advanced indexing
out1 = emb1(x[emb1_idx.split(1, 1)])
out2 = emb2(x[emb2_idx.split(1, 1)] - 5)  # shift indices into emb2's local range

# Scatter both outputs into a common buffer to check for equal results
out_cat = torch.zeros(x.size(0), x.size(1), 300)
out_cat[emb1_idx.split(1, 1)] = out1
out_cat[emb2_idx.split(1, 1)] = out2

# Get reference output
out = emb(x)

# Compare outputs
print((out == out_cat).all())

# Call backward and compare gradients
out_cat.mean().backward()
out.mean().backward()

print((torch.cat((emb1.weight.grad, emb2.weight.grad), 0) == emb.weight.grad).all())

In your code, you would have to push the embedding “sublayers” as well as the corresponding input chunks to different devices.
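
A minimal sketch of how that could look (untested here; the cuda:0/cuda:1 device ids and the two-shard split are just assumptions for a 2-GPU setup, reusing the index logic from the dummy code above):

import torch
import torch.nn as nn

# Untested sketch: each vocabulary shard lives on its own GPU,
# the lookups run per shard, and the results are gathered on one device.
emb1 = nn.Embedding(5, 300).to('cuda:0')   # holds rows 0-4
emb2 = nn.Embedding(5, 300).to('cuda:1')   # holds rows 5-9

x = torch.randint(0, 10, (3, 7))           # input indices on the CPU

emb1_idx = (x < 5).nonzero()
emb2_idx = (x >= 5).nonzero()

# Look up each chunk on the GPU that holds its shard
out1 = emb1(x[emb1_idx.split(1, 1)].to('cuda:0'))
out2 = emb2((x[emb2_idx.split(1, 1)] - 5).to('cuda:1'))

# Gather the outputs on one device (here cuda:0) for the rest of the model
out_cat = torch.zeros(x.size(0), x.size(1), 300, device='cuda:0')
out_cat[tuple(i.to('cuda:0') for i in emb1_idx.split(1, 1))] = out1
out_cat[tuple(i.to('cuda:0') for i in emb2_idx.split(1, 1))] = out2.to('cuda:0')

The .to('cuda:0') calls on out2 and on the index tensors are where the cross-device copies happen, so it might be worth checking how expensive they are for your batch size.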

Let me know if I’m missing something obvious.
