DataParallel: "Expected all tensors to be on the same device" issue

I have the following DataParallel code:

model = Model(args)
model = nn.DataParallel(model, device_ids = [0,1])
model = model.cuda()

for data in train_data:
    data = data_to_cuda(data)
    predicted_output = model(data)
    loss = compute_loss(predicted_output, data['labels'])

Now I am getting this error:

return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

I traced the error to the forward call of one of the model's submodules, which is as follows:

class ImageEncoder:
    def __init__(self, args):
        resnet_model = torchvision.models.resnet18()
        self.model = torch.nn.Sequential(*(list(resnet_model.children())[:-1]))
    
    def forward(self, x):
        x = x.float()
        output = self.model(x)
        return output 

I printed the following values after entering the forward call of ImageEncoder:

    def forward(self, x):
        print('x', x.device)
        x = x.float()
        print('x.float', x.device)
        print('model', self.model[4][0].conv1.weight.device)
        output = self.model(x).squeeze(-1).squeeze(-1)
        print('output', output.device)
        return output

I got the following output:

x cuda:0
x cuda:1
x.float cuda:1
model cuda:0
x.float cuda:0
model cuda:0
output cuda:0

It seems that ImageEncoder is only copied to one device. Can someone please explain what is wrong with my code?

Your code works fine for me after adding the missing nn.Module base class (and the corresponding super().__init__() call) in ImageEncoder:

import torch
import torch.nn as nn
import torchvision

class ImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        resnet_model = torchvision.models.resnet18()
        self.model = torch.nn.Sequential(*(list(resnet_model.children())[:-1]))

    def forward(self, x):
        x = x.float()
        output = self.model(x)
        return output

model = ImageEncoder()
model = nn.DataParallel(model, device_ids = [0,1])
model = model.cuda()

x = torch.randn(2, 3, 224, 224).cuda()
out = model(x)
print(out.shape)
# torch.Size([2, 512, 1, 1])
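
For context on why the original snippet fails: when ImageEncoder is a plain Python class rather than an nn.Module, the ResNet stored inside it is never registered as a submodule, so model.cuda() and DataParallel's replication step cannot see its parameters. A minimal sketch of this (the class names here are made up for illustration):

import torch
import torch.nn as nn
import torchvision

class PlainEncoder:  # note: NOT derived from nn.Module
    def __init__(self):
        self.model = torchvision.models.resnet18()

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = PlainEncoder()  # stored as a plain attribute, not registered as a submodule

w = Wrapper()
print(len(list(w.parameters())))  # 0 -> nothing for .cuda() or DataParallel to move or replicate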

Thanks for your timely response.
Let me give you more context:

My code is actually this:

class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        resnet_model = torchvision.models.resnet18()
        self.model = torch.nn.Sequential(*(list(resnet_model.children())[:-1]))

    def forward(self, x):
        print('x', x.device)
        x = x.float()
        print('x.float', x.device)
        print('model', self.model[4][0].conv1.weight.device)
        output = self.model(x).squeeze(-1).squeeze(-1)
        print('output', output.device)
        return output

class VideoEncoder(nn.Module):
    def __init__(self):
        super(VideoEncoder, self).__init__()
        self.model =  ImageEncoder()

    def forward(self, x):
        x = self.model(x.flat(0,1))
        return x

class FinalModel(nn.Module):
    def __init__(self):
        supe(FinalModel, self).__init__()
        self.model = VideoEncoder()
    
    def forward(self, x):
        return self.model(x)


def data_to_cuda(data):
    for d in data:
        if(type(data[d]) == dict):
            for p in data[d]:
                data[d][p] = data[d][p].cuda()
        else:
            data[d] = data[d].cuda()
    return data

model = FinalModel()
model= nn.DataParallel(model, device_ids=[0,1])
model = model.cuda()

loss = torch.tensor(0.0).cuda()
losses = []

optimizer = Adam(model.parameters(), lr = 0.0001)

for data in train_data:
    data = data_to_cuda(data)
    predicted_labels = model(data)
    loss_i = compute_loss(data['labels'], predicted_labels)
    loss = loss + loss_i
      
    if(self.steps % backprop_every == 0):
        loss = loss / self.backprop_every
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        loss = torch.tensor(0.0).cuda()

        losses.append(loss.item())  
    steps += 1
            

and finally I am running the code as

python train.py

Your code is unfortunately still not executable: e.g. the data is undefined (and not replaced by random data), x.flat is invalid, supe(FinalModel... is a typo, etc.
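
As an aside, if x.flat(0,1) was meant to merge the batch and frame dimensions of a video tensor before passing it to the image encoder, the call you are probably looking for is torch.Tensor.flatten (just a guess on my side):

import torch

x = torch.randn(2, 8, 3, 224, 224)  # hypothetical video batch: (batch, frames, channels, height, width)
x = x.flatten(0, 1)                 # merge batch and frame dims -> (16, 3, 224, 224)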

However, after fixing these issues, the code still works:

import torch
import torch.nn as nn
import torchvision


class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        resnet_model = torchvision.models.resnet18()
        self.model = torch.nn.Sequential(*(list(resnet_model.children())[:-1]))

    def forward(self, x):
        print('x', x.device)
        x = x.float()
        print('x.float', x.device)
        print('model', self.model[4][0].conv1.weight.device)
        output = self.model(x).squeeze(-1).squeeze(-1)
        print('output', output.device)
        return output

class VideoEncoder(nn.Module):
    def __init__(self):
        super(VideoEncoder, self).__init__()
        self.model =  ImageEncoder()

    def forward(self, x):
        x = self.model(x)
        return x

class FinalModel(nn.Module):
    def __init__(self):
        super(FinalModel, self).__init__()
        self.model = VideoEncoder()

    def forward(self, x):
        return self.model(x)


model = FinalModel()
model = nn.DataParallel(model, device_ids=[0,1])
model = model.cuda()


loss = torch.tensor(0.0).cuda()
losses = []

optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

x = torch.randn(2, 3, 224, 224).cuda()
out = model(x)
print(out.shape)

Output:

x cuda:0
x cuda:1
x.float cuda:0
x.float cuda:1
model cuda:0
model cuda:1
output cuda:1
output cuda:0
torch.Size([2, 512])
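
Each debug print shows up once per GPU, which is expected: nn.DataParallel splits the input along the batch dimension (here the batch of 2 becomes one sample per device) and runs a model replica on each GPU, so the cuda:0 and cuda:1 prints interleave.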