Distributed data parallel for a model whose batch size changes in the forward pass

I am trying to implement the Fourier-MIONet architecture from this paper ([2303.04778] Fourier-MIONet: Fourier-enhanced multiple-input neural operators for multiphase modeling of geological carbon sequestration). In this architecture, the effective batch size changes during the forward pass. Below is the code for the network up to the branch-trunk operation (please refer to Table 2 in the paper).

```python
def forward(self, xfield, xscalars, xtime):
    batchsize_x = xfield.shape[0]
    size_x, size_y = xfield.shape[1], xfield.shape[2]
    size_t = xtime.shape[-1]
    # ==================================== MIONet ====================================
    x_b1 = self.fc_b1(xfield)  # [bs, nz, nx, width]
    x_b2 = self.fc_b2(xscalars)  # [bs, width]
    x_branch = x_b1 + x_b2[:, None, None, :]  # [bs, nz, nx, width]
    x_trunk = self.fc_t1(xtime)  # [nt, width]
    x_branch = x_branch.unsqueeze(1)
    x_branch = x_branch.repeat(1, size_t, 1, 1, 1)  # [bs, nt, nz, nx, width]

    # Reshape trunk output to match the dimensions of the branch output
    x_trunk = x_trunk.unsqueeze(0).unsqueeze(-2).unsqueeze(-2)  # [1, nt, 1, 1, width]
    # Multiply branch and trunk outputs
    x = x_branch * x_trunk
    # Fold time into the batch dimension: the batch size becomes bs * nt
    x = x.reshape((batchsize_x * size_t, size_x, size_y, -1))
    # ==================================== Fourier layers ====================================
    # (Fourier layers omitted from this excerpt)
    x = x.reshape((batchsize_x, size_t, size_x, size_y, -1))
    return x
```
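As a shape check on the branch-trunk product above, here is a standalone CPU sketch. The layer types and sizes are assumptions (plain `nn.Linear` stand-ins for `fc_b1`, `fc_b2` and `fc_t1`, with made-up dimensions), and `xtime` is assumed to be `[nt, 1]`, so `nt` is set directly here rather than read from `xtime.shape[-1]`. It also shows that broadcasting gives the same result as the explicit `repeat`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
bs, nz, nx, nt, width = 2, 8, 16, 5, 32
n_field_ch, n_scalars, n_time_feat = 3, 4, 1

# Assumed stand-ins for the branch/trunk layers: plain nn.Linear layers.
fc_b1 = nn.Linear(n_field_ch, width)   # field branch
fc_b2 = nn.Linear(n_scalars, width)    # scalar branch
fc_t1 = nn.Linear(n_time_feat, width)  # trunk (time)

xfield = torch.randn(bs, nz, nx, n_field_ch)
xscalars = torch.randn(bs, n_scalars)
xtime = torch.linspace(0.0, 1.0, nt).unsqueeze(-1)  # assumed shape [nt, 1]

x_branch = fc_b1(xfield) + fc_b2(xscalars)[:, None, None, :]  # [bs, nz, nx, width]
x_trunk = fc_t1(xtime)                                         # [nt, width]

# Version with an explicit repeat, as in forward().
x_rep = x_branch.unsqueeze(1).repeat(1, nt, 1, 1, 1) * x_trunk[None, :, None, None, :]

# Broadcasting alone yields the same [bs, nt, nz, nx, width] tensor.
x_bc = x_branch.unsqueeze(1) * x_trunk[None, :, None, None, :]
print(x_bc.shape)  # torch.Size([2, 5, 8, 16, 32])

# Fold time into the batch dimension for the Fourier layers, then unfold.
x_flat = x_bc.reshape(bs * nt, nz, nx, width)
x_back = x_flat.reshape(bs, nt, nz, nx, width)
```

Dropping the `repeat` avoids one extra `[bs, nt, nz, nx, width]` copy; the product itself still allocates a tensor of that size.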

The model works on a single GPU. However, I cannot get it to work with PyTorch's DataParallel. With 4 GPUs and a batch size of 4, the output of the DataParallel-wrapped model contains only one sample instead of four.

```python
for x, s, y in train_loader:
    x, s, y = x.to(device), s.to(device), y.to(device)
    t = train_t.to(device)  # time grid shared by all samples
    pred = dp_model(x, s, t)
```

When I tested my code with 4 GPUs and a batch size of 4 in train_loader, pred contained only 1 sample.
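One thing that may be worth checking, since the shape of `train_t` is not shown: `nn.DataParallel` splits every tensor positional argument along dim 0 across the devices, including `t`, not only `x` and `s`. The per-argument chunks are then zipped together, so the number of replicas that actually run is the smallest number of chunks produced by any argument. If `t` has a leading dimension of 1, only one replica runs on one sample, which would match the symptom. The sketch below imitates that splitting on CPU with `torch.chunk`; the helper `simulate_dp_scatter` and all sizes are hypothetical, and this is not PyTorch's actual implementation.

```python
import torch


def simulate_dp_scatter(inputs, n_devices):
    """Rough CPU imitation of how nn.DataParallel distributes positional tensor
    inputs: each tensor is chunked along dim 0 and the chunks are zipped, so the
    number of replicas equals the smallest number of chunks. Illustration only."""
    chunked = [t.chunk(n_devices, dim=0) for t in inputs]
    return list(zip(*chunked))


bs, nt, n_gpus = 4, 10, 4
x = torch.randn(bs, 8, 16, 3)
s = torch.randn(bs, 4)

# Case 1: time grid shaped [nt, 1]. It is split too, so each replica
# sees only a slice of the time steps.
t_col = torch.linspace(0, 1, nt).unsqueeze(-1)
replicas = simulate_dp_scatter((x, s, t_col), n_gpus)
print(len(replicas), [r[2].shape[0] for r in replicas])  # 4 [3, 3, 3, 1]

# Case 2: time grid shaped [1, nt]. It yields a single chunk, so zip
# truncates to one replica holding one sample.
t_row = torch.linspace(0, 1, nt).unsqueeze(0)
replicas = simulate_dp_scatter((x, s, t_row), n_gpus)
print(len(replicas), replicas[0][0].shape[0])  # 1 1

# Hypothetical workaround: give t a leading batch dim so it is split like x;
# inside forward() one would then use xtime[0] as the shared [nt, 1] grid.
t_batched = t_col.unsqueeze(0).expand(bs, -1, -1)  # [bs, nt, 1]
replicas = simulate_dp_scatter((x, s, t_batched), n_gpus)
print(len(replicas), replicas[0][2][0].shape)  # 4 torch.Size([10, 1])
```

Another option that may fit is registering the time grid as a buffer on the module with `self.register_buffer(...)`, since DataParallel copies buffers to every replica instead of splitting them.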

Just to confirm: are you using DataParallel or DistributedDataParallel? Changing the batch size inside forward should be fine with DistributedDataParallel, because each process runs its own forward pass on its own inputs. I don't see why the output is being reduced to a single sample.

I am actually using DataParallel. I will also try DistributedDataParallel. I am not sure whether DataParallel supports a batch size that changes during the forward pass.
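For comparison, DistributedDataParallel does not scatter inputs: each process calls the wrapped module on its own inputs, so the full time grid reaches `forward` unchanged and the internal reshape to `bs * nt` stays local to each process. Below is a minimal single-process sketch, assuming the `gloo` backend on CPU and a file-based rendezvous so it runs without GPUs; `ToyBranchTrunk` and all sizes are hypothetical.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyBranchTrunk(nn.Module):
    """Hypothetical stand-in for the branch/trunk part of the model."""

    def __init__(self, n_field_ch=3, n_scalars=4, width=32):
        super().__init__()
        self.fc_b1 = nn.Linear(n_field_ch, width)
        self.fc_b2 = nn.Linear(n_scalars, width)
        self.fc_t1 = nn.Linear(1, width)

    def forward(self, xfield, xscalars, xtime):
        bs, nz, nx, _ = xfield.shape
        nt = xtime.shape[0]  # assumes xtime is [nt, 1]
        x_branch = self.fc_b1(xfield) + self.fc_b2(xscalars)[:, None, None, :]
        x_trunk = self.fc_t1(xtime)
        x = x_branch.unsqueeze(1) * x_trunk[None, :, None, None, :]
        x = x.reshape(bs * nt, nz, nx, -1)  # batch grows to bs * nt here
        # (Fourier layers would operate on the [bs * nt, nz, nx, width] tensor.)
        return x.reshape(bs, nt, nz, nx, -1)


def run(rank, world_size, init_file):
    # gloo + CPU so the sketch runs without GPUs; on GPUs one would use
    # backend="nccl", move the model to cuda:rank and pass device_ids=[rank].
    dist.init_process_group("gloo", init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    model = DDP(ToyBranchTrunk())
    x = torch.randn(4, 8, 16, 3)  # this rank's own batch
    s = torch.randn(4, 4)
    t = torch.linspace(0, 1, 10).unsqueeze(-1)  # full time grid, never split
    pred = model(x, s, t)
    pred.mean().backward()  # gradients are all-reduced across ranks here
    dist.destroy_process_group()
    return pred.shape


fd, init_file = tempfile.mkstemp()
os.close(fd)
shape = run(rank=0, world_size=1, init_file=init_file)
print(shape)  # torch.Size([4, 10, 8, 16, 32])
```

In a real multi-GPU run one would typically launch one process per GPU (for example with `torchrun`), use `backend="nccl"`, and give each rank its own shard of the data via `torch.utils.data.distributed.DistributedSampler`.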