Problem in training on multi-GPU with DataParallel

Hello everyone,
I’m getting trouble using DataParallel in my multi-task learning model.
First, I wrapped my model using DataParallel to train my model on 8 GPUs,

# Define the network.
import pretrainedmodels
class CNN1(nn.Module):
    def __init__(self):
        super(CNN1, self).__init__()
        self.model = pretrainedmodels.__dict__["resnet34"](pretrained="imagenet")
        self.fc1 = nn.Linear(512, 8)  #For age class
        self.fc2 = nn.Linear(512, 3)    #For gender class
        self.fc3 = nn.Linear(512, 3)    #For race class

    def forward(self, x):
        bs, _, _, _ = x.shape
        x = self.model.features(x)
        x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
        label1 = self.fc1(x)
        label2 = self.fc2(x)
        label3 = self.fc3(x)
        return {'situation': label1, 'position': label2, 'direction': label3}

net = CNN1()

# Define if device is cuda
net = nn.DataParallel(net, device_ids = [0, 1, 2])
net.to(0)


criterion_1 = nn.CrossEntropyLoss()
criterion_2 = nn.CrossEntropyLoss()
criterion_3 = nn.CrossEntropyLoss()

# 4. Define the optimizer
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 5. Setup the net to train
net.train()

Then, I started training with this network with the code below:

for epoch in tqdm(range(10)):
    for i, data in tqdm(enumerate(train_loader)):
        inputs = data["image"].to(0)
        situation_label = data["situation"].to(0)
        position_label = data["position"].to(0)
        direction_label = data["direction"].to(0)

        # forward to the net
        optimizer.zero_grad()
        outputs = net(inputs)

        situation_outputs = outputs['situation']
        position_outputs = outputs['position']
        direction_outputs = outputs['direction']

        loss_1 = criterion_1(situation_outputs, situation_label)
        loss_2 = criterion_2(position_outputs, position_label)
        loss_3 = criterion_3(direction_outputs, direction_label)

        loss = loss_1 + loss_2 + loss_3
        loss.backward()
        optimizer.step()

    # enter validation at evey 2 epochs
    if (epoch + 1) % 2 == 0:
        torch.save(net.state_dict(), "./checkpoints/UTKFaceCNN_" + str(epoch + 1) + ".pth")
        print(">>>Validating<<<")
        # save the model
        for j, data in tqdm(enumerate(val_loader)):
            inputs = data["image"].to(device=device)
            situation_label = data["situation"].to(device=device)
            position_label = data["position"].to(device=device)
            direction_label = data["direction"].to(device=device)
            output = net(inputs)

            loss_1 = criterion_1(situation_outputs, situation_label)
            loss_2 = criterion_2(position_outputs, position_label)
            loss_3 = criterion_3(direction_outputs, direction_label)

            loss = loss_1 + loss_2 + loss_3

I got this error message:

RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/.pyenv/versions/anaconda3-2020.02/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/.pyenv/versions/anaconda3-2020.02/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "<ipython-input-20-c78d37c43833>", line 17, in forward
    x = self.model.features(x)
  File "/.pyenv/versions/anaconda3-2020.02/lib/python3.7/site-packages/pretrainedmodels/models/torchvision_models.py", line 322, in features
    x = self.conv1(input)
  File "/.pyenv/versions/anaconda3-2020.02/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/.pyenv/versions/anaconda3-2020.02/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 423, in forward
    return self._conv_forward(input, self.weight)
  File "/.pyenv/versions/anaconda3-2020.02/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 420, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

Can anyone show me the solution to fix this prolem? And how to set GPU device correctly?

Thank you and best regards.

The modify_resnets method of pretrainedmodels seems to break nn.DataParallel.
If I remove this line of code in this library, the model works fine, and also the torchvision.models.resnet34 model works fine using your code.

My best guess is that the reassignment of the forward method could break it for some reason.