Tensor parallelism: simple Embedding example

Hi,
I was going through the tensor parallel examples. After spending a few hours, I still can't get my head around it.
Let's assume we have the following class and 2 GPUs.
How can we use ColwiseParallel tensor parallelism to store the first half of the entity_embeddings and of the relation_embeddings on the first GPU and the other halves on the second GPU?


import torch
import torch.nn as nn

class DistMult(nn.Module):
    def __init__(self):
        super().__init__()
        self.entity_embeddings = torch.nn.Embedding(135, 32)
        self.relation_embeddings = torch.nn.Embedding(46, 32)

    def forward(self, h, r, t):
        h = self.entity_embeddings(h)
        r = self.relation_embeddings(r)
        t = self.entity_embeddings(t)
        return ((h * r) * t).sum(dim=1)

Whatever I try, the following RuntimeError occurs:
RuntimeError: Function EmbeddingBackward0 returned an invalid gradient at index 0 - got [135, 32] but expected shape compatible with [46, 32]