Shape of output becomes wrong with nn.DataParallel

My model takes two batched sequences (along with their attention masks) and is supposed to return a tensor of shape (batch_size, batch_size) whose (i, j) element is the cosine similarity between the i-th element of the first batch (b_que_iis[i]) and the j-th element of the second batch (b_art_iis[j]).
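For reference, the intended computation can be sketched on the CPU with random stand-in vectors (`q` and `a` below are hypothetical placeholders for the BERT [CLS] outputs). After unit-normalizing each row, the matrix product of the two batches gives every pairwise cosine similarity; the sketch checks this against `F.cosine_similarity` with broadcasting.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, hidden = 16, 8
q = torch.randn(batch_size, hidden)  # hypothetical stand-in for query [CLS] vectors
a = torch.randn(batch_size, hidden)  # hypothetical stand-in for article [CLS] vectors

# Normalize each row to unit length, then take all pairwise dot products.
sim = F.normalize(q, dim=1) @ F.normalize(a, dim=1).T

# Reference: cosine similarity of every (i, j) pair via broadcasting
# (16, 1, 8) against (1, 16, 8) -> (16, 16).
ref = F.cosine_similarity(q.unsqueeze(1), a.unsqueeze(0), dim=-1)

print(sim.shape)  # torch.Size([16, 16])
print(torch.allclose(sim, ref, atol=1e-5))
```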

I ran this with batch_size 16 on a machine with 4 GPUs using nn.DataParallel, so each GPU gets a batch of 4 and returns a 4 x 4 tensor.
However, the final output of the model, `sims = model(b_que_iis, b_que_ams, b_art_iis, b_art_ams)`, is a tensor of shape (16, 4).

How can I fix this and get the desired (16, 16) tensor?
I would greatly appreciate any comments.

Thank you!

```python
import torch
import torch.nn as nn

class QA_match(nn.Module):
    def __init__(self, bert_model):
        super(QA_match, self).__init__()
        self.bert = bert_model

    def forward(self, b_que_iis, b_que_ams, b_art_iis, b_art_ams):
        # [CLS] embedding of each question and each article
        query_cls = self.bert(input_ids=b_que_iis,
                              attention_mask=b_que_ams).last_hidden_state[:, 0, :]
        article_cls = self.bert(input_ids=b_art_iis,
                                attention_mask=b_art_ams).last_hidden_state[:, 0, :]
        # Unit-normalize so that dot products are cosine similarities
        query_cls = nn.functional.normalize(query_cls, dim=1)
        article_cls = nn.functional.normalize(article_cls, dim=1)

        # sim[i, j] = cos(query i, article j), shape (batch, batch)
        sim = torch.einsum('ij,kj->ik', query_cls, article_cls)
        return sim
```

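This is expected: nn.DataParallel gathers the outputs of its replicas by concatenating them along dim 0.
Each device returns a 4 x 4 tensor (16 values), so the gathered result from 4 devices is a (16, 4) tensor holding 4*4*4 = 64 values, not the 16*16 = 256 values you want.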
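To make the shape arithmetic concrete, the sketch below mimics DataParallel's scatter and gather on the CPU without any GPUs. Using `torch.chunk` for the scatter is an assumption for illustration (it matches the equal-sized split of this example); the tensors are random stand-ins for the normalized [CLS] vectors. It shows that the gathered (16, 4) result contains only the four diagonal 4 x 4 blocks of the full (16, 16) similarity matrix.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_gpus, batch_size, hidden = 4, 16, 8
q = F.normalize(torch.randn(batch_size, hidden), dim=1)  # stand-in query embeddings
a = F.normalize(torch.randn(batch_size, hidden), dim=1)  # stand-in article embeddings

# "Scatter": each input is split along dim 0, one chunk of 4 rows per device.
q_chunks = q.chunk(n_gpus, dim=0)
a_chunks = a.chunk(n_gpus, dim=0)

# Each replica only sees its own chunk, so it can only build a local 4x4 matrix.
per_device = [qc @ ac.T for qc, ac in zip(q_chunks, a_chunks)]

# "Gather": the per-device outputs are concatenated along dim 0.
gathered = torch.cat(per_device, dim=0)
print(gathered.shape)  # torch.Size([16, 4])

# The full matrix the question asks for.
full = q @ a.T
print(full.shape)  # torch.Size([16, 16])

# The gathered rows are exactly the diagonal blocks of the full matrix;
# the cross-device similarities were never computed.
step = batch_size // n_gpus
blocks = [full[k * step:(k + 1) * step, k * step:(k + 1) * step] for k in range(n_gpus)]
print(torch.allclose(gathered, torch.cat(blocks, dim=0), atol=1e-6))
```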

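The missing entries are similarities between samples that were sent to different devices, so no single replica can compute them. Since these values are a function of pairs of samples in the batch, you could let the model return the normalized embeddings and compute the similarity matrix after nn.DataParallel has gathered them.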
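A minimal sketch of that restructuring, assuming a hypothetical `ToyEncoder` in place of BERT so that it runs without downloading a model (it returns a masked mean of token embeddings where the real code takes `last_hidden_state[:, 0, :]`). The hypothetical `QA_encoder` does only per-sample work inside `forward`, and the einsum runs on the gathered (16, hidden) tensors. Without GPUs, nn.DataParallel simply calls the wrapped module, so the same code runs on the CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the BERT model: one vector per sequence."""
    def __init__(self, vocab_size=100, hidden=8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask):
        # Masked mean over tokens, playing the role of the [CLS] vector here.
        mask = attention_mask.unsqueeze(-1).float()
        return (self.emb(input_ids) * mask).sum(1) / mask.sum(1).clamp(min=1.0)


class QA_encoder(nn.Module):
    """Hypothetical model: per-sample work only, no cross-sample interaction."""
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, b_que_iis, b_que_ams, b_art_iis, b_art_ams):
        query_cls = F.normalize(self.encoder(b_que_iis, b_que_ams), dim=1)
        article_cls = F.normalize(self.encoder(b_art_iis, b_art_ams), dim=1)
        return query_cls, article_cls  # each replica returns (4, hidden)


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.DataParallel(QA_encoder(ToyEncoder())).to(device)

batch_size, seq_len = 16, 12
b_que_iis = torch.randint(1, 100, (batch_size, seq_len))
b_art_iis = torch.randint(1, 100, (batch_size, seq_len))
b_que_ams = torch.ones(batch_size, seq_len, dtype=torch.long)
b_art_ams = torch.ones(batch_size, seq_len, dtype=torch.long)

# DataParallel gathers each element of the returned tuple along dim 0.
query_cls, article_cls = model(b_que_iis, b_que_ams, b_art_iis, b_art_ams)

# Cross-sample step, done after the full batch has been gathered.
sims = torch.einsum('ij,kj->ik', query_cls, article_cls)
print(sims.shape)  # torch.Size([16, 16])
```

With the real model, only the einsum moves out of `forward`; the BERT calls and the normalization stay inside. The gather step is differentiable, so a loss computed on `sims` still backpropagates into every replica.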

Ah, I see. So I should do the einsum outside the model.
Thank you very much!