Input and weights on different GPUs after 1 epoch when sharding the model

Oh yes, thank you for your reply! I re-checked the main function and found that my model was moved to the CPU to save checkpoints and then moved back to a single GPU as a whole.
Now, after every saving operation, I reassign the device of each submodule. That worked for one batch, but when I additionally used DataParallel, as your post suggested, it threw the following error:

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1
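
For context, my understanding is that nn.DataParallel checks at forward time that every parameter and buffer of the wrapped module lives on device_ids[0]. This minimal snippet (not my actual code, just to illustrate the check) reproduces the same message on a machine with at least 2 GPUs:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4).cuda(1)                 # parameters end up on cuda:1
dp = nn.DataParallel(layer, device_ids=[0, 1])  # DataParallel expects them on cuda:0
out = dp(torch.randn(2, 4).cuda(0))             # RuntimeError: ... found one of them on device: cuda:1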

My code now is like:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, some_param):
        super().__init__()
        self.large_submodule1 = SubModel1(...)
        self.large_submodule2 = SubModel2(...)

        if len(gpu_ids) == 2:
            # one submodule per GPU
            self.large_submodule1.cuda(0)
            self.large_submodule2.cuda(1)
        elif len(gpu_ids) == 4:
            # each submodule data-parallel over two GPUs
            self.large_submodule1 = nn.DataParallel(self.large_submodule1, device_ids=[0, 1]).to('cuda:0')
            self.large_submodule2 = nn.DataParallel(self.large_submodule2, device_ids=[2, 3]).to('cuda:2')

    def forward(self, x):
        x1, x2, cat = self.large_submodule1(x)  # error occurs here
        device = x1.device.index  # don't know if there's a better way to do this
        x1 = x1.cuda(device + 2)
        x2 = x2.cuda(device + 2)
        cat = cat.cuda(device + 2)
        out = self.large_submodule2(x1, x2, cat)
        return out.cuda(device)

class Seg(BaseModel):
    def initialize(self, opts, **kwargs):
        self.net = MyModel(some_param)

    def save_network(self, path):
        torch.save(self.net.cpu().state_dict(), path + '/model.pth')
        # move the submodules back to their GPUs after saving
        if len(gpu_ids) == 2:
            self.net.large_submodule1.cuda(0)
            self.net.large_submodule2.cuda(1)
        elif len(gpu_ids) == 4:
            self.net.large_submodule1 = nn.DataParallel(self.net.large_submodule1, device_ids=[0, 1]).to('cuda:0')
            self.net.large_submodule2 = nn.DataParallel(self.net.large_submodule2, device_ids=[2, 3]).to('cuda:2')

With 2 GPUs it now works fine, but with 4 GPUs it throws the above error at the next epoch after saving a checkpoint. Could you help me with this?
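
In case it is relevant, one workaround I was considering (just a sketch, not tested) is to save a CPU copy of the state dict instead of calling self.net.cpu(), so the live model never leaves its GPUs and nothing has to be re-wrapped after a save:

    def save_network(self, path):
        # copy every tensor to CPU for the checkpoint; the model itself stays on its GPUs
        cpu_state = {k: v.cpu() for k, v in self.net.state_dict().items()}
        torch.save(cpu_state, path + '/model.pth')

Would that avoid the problem, or is something else going on in the 4-GPU case? Many thanks!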