Outputs from nn.DataParallel are not concatenated

After checking the shapes of all the tensors, I realized that the query tensor was being broadcast inside torch.einsum. The query is an additional input to the model, so it needs to be repeated along the batch dimension so that nn.DataParallel can split it across the GPUs together with the other inputs. This problem therefore has the same solution as mentioned here. Apologies for any confusion!
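
For anyone hitting the same issue, here is a minimal sketch (not my actual model; the module and names like `AttnPool` are just illustrative) of what the fix looks like: the extra query input is given an explicit batch dimension by repeating it before the forward call, so DataParallel can scatter it along dim 0 and gather the per-GPU outputs back into one tensor.

```python
import torch
import torch.nn as nn


class AttnPool(nn.Module):
    def forward(self, x, query):
        # x:     (batch, seq_len, dim)
        # query: (batch, dim) -- must carry the batch dimension explicitly;
        # a broadcastable (1, dim) query would not be split by DataParallel.
        scores = torch.einsum("bsd,bd->bs", x, query)   # (batch, seq_len)
        weights = scores.softmax(dim=-1)
        return torch.einsum("bs,bsd->bd", weights, x)   # (batch, dim)


model = nn.DataParallel(AttnPool()).cuda()

batch, seq_len, dim = 8, 16, 32
x = torch.randn(batch, seq_len, dim).cuda()
query = torch.randn(dim).cuda()

# Repeat the single query so its batch size matches x; DataParallel then
# scatters both tensors along dim 0 and concatenates the outputs back
# into a full (batch, dim) tensor.
out = model(x, query.unsqueeze(0).repeat(batch, 1))
print(out.shape)  # torch.Size([8, 32])
```

With the broadcast (un-repeated) query, each replica only sees its own chunk of `x` but the full query, and the gathered output shapes no longer line up as expected; repeating the query along dim 0 makes every input to the module scatterable in the same way.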