Dimension change of an additional learnable parameter in a model when using DataParallel

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2'

class model1(nn.Module):
    def __init__(self, const, dim_in, feat_dim):
        super(model1, self).__init__()
        # encoder + projection head
        self.encoder = resnet50(num_classes=dim_in)  # backbone producing [batch, dim_in] features
        self.head = nn.Linear(dim_in, feat_dim)

        # additional learnable parameter, shape [const, feat_dim]
        self.additional_param = nn.Parameter(torch.randn(const, feat_dim))
        nn.init.uniform_(self.additional_param, a=0.0, b=1.0)

    def forward(self, x):
        # return the parameter alongside the features so it can be used in the loss
        additional_param_out = self.additional_param

        x = self.encoder(x)
        x = self.head(x)
        x = F.normalize(x, dim=1)

        return x, additional_param_out

class loss1(nn.Module):
    def __init__(self):
        super(loss1, self).__init__()

    def forward(self, results, additional_param):
        # element-wise product of features and the additional parameter, summed to a scalar
        loss_val = (results * additional_param).sum()

        return loss_val

const, dim_in, feat_dim = 256, 2048, 128  # example values
net = model1(const, dim_in, feat_dim)
net = torch.nn.DataParallel(net).cuda()  # parameters must be on the default GPU for DataParallel
loss = loss1()

for batch_idx, (img, target) in enumerate(batch_iterator):

    result1, result2 = net(img)
    # result1.shape = [batch size, feat_dim]
    # result2.shape = ?
    loss_val = loss(result1, result2)

The code above is a somewhat simplified version of my original code. The key point is that, in addition to the encoder and projection head, the model holds a randomly initialized learnable parameter (additional_param), and this parameter is used in the loss calculation together with the encoded features.

When using a single GPU without DataParallel, there is no issue. However, when using multiple GPUs with DataParallel as in the code above, the shape of result2 changes, which breaks the loss computation.

For example, when const equals the batch size, additional_param has the shape [batch size, feat_dim]. But when using 3 GPUs, result2.shape becomes [3 * batch size, feat_dim] (instead of [batch size, feat_dim] or [3, batch size, feat_dim]), because DataParallel gathers each replica's outputs along dim 0, and this mismatch causes issues during the loss computation.
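
To illustrate, here is a minimal sketch with toy shapes (names like ToyModel are just illustrative, and it assumes the 3 visible GPUs from above): the features are gathered back to the global batch size, while the returned parameter ends up stacked once per replica.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, const=4, feat_dim=8):
        super().__init__()
        self.param = nn.Parameter(torch.randn(const, feat_dim))

    def forward(self, x):
        # x arrives with the local (per-GPU) batch size;
        # self.param is an identical copy on every replica
        return x, self.param

toy = nn.DataParallel(ToyModel()).cuda()
feats, param_out = toy(torch.randn(12, 8))  # global batch size 12
print(feats.shape)      # torch.Size([12, 8]) -> features gathered back to the global batch
print(param_out.shape)  # torch.Size([12, 8]) with 3 GPUs, i.e. [3 * const, feat_dim], not [const, feat_dim]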

A simple solution could be to use result2[:(batch size)] in the loss calculation, but I am not sure if this is the correct approach. Could you suggest an appropriate way to resolve this issue?
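
For concreteness, here is a sketch of that slicing workaround next to an alternative that avoids the gather entirely by reading the parameter from the wrapped module (net.module is the underlying model1 instance); I am not sure which, if either, is the intended approach:

# option 1: slice the gathered copy back to one replica's worth
loss_val = loss(result1, result2[:const])  # every replica returns identical values, so this is replica 0's copy

# option 2: ignore the returned tensor and read the parameter directly
additional_param = net.module.additional_param  # still [const, feat_dim], never passed through the gather
loss_val = loss(result1, additional_param)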

Your explanation of the root cause sounds correct: the global batch size is divided by the number of GPUs, giving the local batch size seen inside the forward method, and each replica's outputs are then gathered along dim 0. However, parameters etc. are usually not initialized using the batch size, so could you explain why you are using the batch size in dim 0? Are you computing a per-batch loss? If so, outside of the forward method your model output should have the global batch size again.
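
For example, a quick way to see that split and gather (a toy probe module; shapes assume 3 GPUs and a global batch size of 12):

class ShapeProbe(nn.Module):
    def forward(self, x):
        print("local batch inside forward:", x.shape)  # torch.Size([4, 128]), printed once per GPU
        return x * 2

probe = nn.DataParallel(ShapeProbe()).cuda()
y = probe(torch.randn(12, 128))
print("global batch outside forward:", y.shape)        # torch.Size([12, 128]) again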