Input and weights on different GPUs after 1 epoch when sharding a model

I have read posts about model sharding and found some really good examples.
However, when I implemented it myself, I got the following error at the 2nd epoch:

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 0 does not equal 1 (while checking arguments for cudnn_convolution)

My code is like:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, some_param):
        super().__init__()
        self.large_submodule1 = SubModel1(...)
        self.large_submodule2 = SubModel2(...)

        # place each half of the model on its own GPU
        self.large_submodule1.cuda(0)
        self.large_submodule2.cuda(1)

    def forward(self, x):
        x1, x2, cat = self.large_submodule1(x)
        # move the intermediate activations to the second GPU
        x1 = x1.cuda(1)
        x2 = x2.cuda(1)
        cat = cat.cuda(1)
        out = self.large_submodule2(x1, x2, cat)
        return out.cuda(0)  # because the ground truth is stored on cuda:0

class SubModel1(nn.Module):
    def __init__(self, some_param):
        super().__init__()
        self.conv1 = ...
        self.conv2 = ...
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        cat = torch.cat((x1, x2), dim=1)
        return x1, x2, cat

class SubModel2(nn.Module):
    def __init__(self, some_param):
        super().__init__()
        self.conv1 = ...
        self.conv2 = ...
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2, cat):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        out = torch.cat((x1, x2, cat), dim=1)
        return out

It worked for the 1st epoch, but threw the above error at the 2nd epoch.
I've read this issue about a similar error, but I didn't use DataParallel and my batch size was 1.

Thank you for your time!

Did you call to() on the model after the first epoch, e.g. in your validation loop?
Could you post a code snippet to reproduce this issue?
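For reference, here is a minimal sketch (with made-up Linear layers, assuming two visible GPUs; not your original code) of why a whole-model to()/cpu()/cuda() call would explain the error: these calls are applied recursively to every submodule, so they undo any manual per-submodule placement.

import torch
import torch.nn as nn

# Toy model: two "halves" placed on different GPUs by hand.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[0].cuda(0)
model[1].cuda(1)
print(next(model[0].parameters()).device, next(model[1].parameters()).device)
# -> cuda:0 cuda:1   (sharded as intended)

# A single call on the parent module, e.g. from a checkpoint or validation helper,
# is applied recursively and moves *both* halves to the same device.
model.cuda(0)
print(next(model[0].parameters()).device, next(model[1].parameters()).device)
# -> cuda:0 cuda:0   (the next forward pass that feeds a tensor from the other GPU
#                     into these layers raises the input/weight device mismatch error)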

Oh yes, thank you for your reply! I re-checked the main function and found that the model was moved to the CPU to save checkpoints and then moved back to a single GPU as a whole.
Now, after every save, I move each submodule back to its own device. That worked for one batch, but when I additionally used DataParallel, as your post suggested, it threw the following error:

RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1

My code now is like:

class MyModel(nn.Module):
    def __init__(self, some_param):
        super().__init__()
        self.large_submodule1 = SubModel1(...)
        self.large_submodule2 = SubModel2(...)

        if len(gpu_ids) == 2:
            self.large_submodule1.cuda(0)
            self.large_submodule2.cuda(1)
        elif len(gpu_ids) == 4:
            self.large_submodule1 = nn.DataParallel(self.large_submodule1, device_ids=[0, 1]).to('cuda:0')
            self.large_submodule2 = nn.DataParallel(self.large_submodule2, device_ids=[2, 3]).to('cuda:2')

    def forward(self, x):
        x1, x2, cat = self.large_submodule1(x)  # error occurs here
        device = x1.device.index  # don't know if there's a better way to do this
        x1 = x1.cuda(device + 2)
        x2 = x2.cuda(device + 2)
        cat = cat.cuda(device + 2)
        out = self.large_submodule2(x1, x2, cat)
        return out.cuda(device)

class Seg(BaseModel):
    def initialize(self, opts, **kwargs):
        self.net = MyModel(some_param)

    def save_network(self, path):
        torch.save(self.net.cpu().state_dict(), path + '/model.pth')
        if len(gpu_ids) == 2:
            self.net.large_submodule1.cuda(0)
            self.net.large_submodule2.cuda(1)
        elif len(gpu_ids) == 4:
            self.net.large_submodule1 = nn.DataParallel(self.net.large_submodule1, device_ids=[0, 1]).to('cuda:0')
            self.net.large_submodule2 = nn.DataParallel(self.net.large_submodule2, device_ids=[2, 3]).to('cuda:2')

Now if I use 2 GPUs it works fine. But if I use 4 GPUs, after saving a checkpoint, it throws the above error at the next epoch. Could you help me with this? Many thanks!

I figured it out. In save_network, the 4-GPU branch wraps the submodules in nn.DataParallel again after every save, even though they are already DataParallel modules. It should just move them back instead:

            self.net.large_submodule1 = self.net.large_submodule1.to('cuda:0')
            self.net.large_submodule2 = self.net.large_submodule2.to('cuda:2')
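As an aside, a way to avoid having to restore the device placement after every save at all (just a sketch of an alternative, not the code above) is to write a CPU copy of the state_dict instead of calling .cpu() on the live model:

import torch

def save_network(net, path):
    # Copy each parameter/buffer tensor to CPU for the checkpoint while leaving
    # the live model (and its DataParallel wrappers) on their GPUs, so nothing
    # has to be moved back afterwards.
    cpu_state = {k: v.cpu() for k, v in net.state_dict().items()}
    torch.save(cpu_state, path + '/model.pth')

With that, the 2-GPU/4-GPU re-placement branches in save_network aren't needed anymore.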