After checking the shapes of all the tensors, I realized that broadcasting was being used for the query tensor in torch.einsum. The query tensor is an extra input with no batch dimension, so it needs to be repeated to a batch size that can be split across the GPUs. This problem therefore has the same solution as mentioned here. Apologies for any confusion!
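For illustration, here is a minimal sketch of the issue and the fix. The tensor names, shapes, and einsum equation are hypothetical stand-ins, not the original code; the point is that a query with no batch dimension is broadcast inside einsum, so nn.DataParallel has nothing to scatter until you repeat it along a batch dimension.

```python
import torch

batch_size, seq_len, dim = 8, 16, 32

keys = torch.randn(batch_size, seq_len, dim)  # per-example input
query = torch.randn(dim)                      # shared query, no batch dim

# Broadcast version: einsum expands `query` implicitly across the batch,
# so DataParallel cannot split it along dimension 0 with the other inputs.
scores_broadcast = torch.einsum('bsd,d->bs', keys, query)

# Fix: repeat the query to an explicit batch dimension so it can be
# scattered across GPUs together with the rest of the batch.
query_batched = query.unsqueeze(0).repeat(batch_size, 1)  # (batch_size, dim)
scores_batched = torch.einsum('bsd,bd->bs', keys, query_batched)

# Both forms compute the same result on a single device.
assert torch.allclose(scores_broadcast, scores_batched)
```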