Error with DataParallel

Hi everyone, I'd really appreciate your help with this. I've been trying to speed up training of my ResNet (essentially ResNet-34, customized to expand the features by adding an extra block at the end). To do this, I wrapped the model in nn.DataParallel. The code below worked fine before I added nn.DataParallel and switched from .to(device) to model.cuda() and inputs.cuda().

The first training run, with only the final output layer and the batchnorm layers unfrozen, works fine. I then take the resulting model, unfreeze all weights in all layers of the network, and run it through the same training loop again. This second run produces the error below. I'm not sure how to fix it, but I suspect it has to do with passing the previously trained model through another training loop (inside which I call model.cuda() again). I've pasted the code I thought was relevant, including the training loop and the partial unfreezing. Thanks again!

```
Traceback (most recent call last):
  File "/wynton/protected/home/ichs/dmandair/BRCA/train_val.py", line 395, in <module>
    trained_model = train_model(trained_model, criterion, optimizer_ft, exp_lr_scheduler, dataloader_dict, device, NUM_EPOCHS)
  File "/wynton/protected/home/ichs/dmandair/BRCA/train_val.py", line 155, in train_model
    outputs = model(inputs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/wynton/protected/home/ichs/dmandair/BRCA/train_val.py", line 222, in forward
    x = self.base(x)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 446, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/wynton/protected/home/ichs/dmandair/anaconda3/envs/ONC_EXP/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)
```

```python
import copy
import time
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import models

# MAX_NORM, BS, NUM_EPOCHS, device, train_tile_ds and val_tile_ds are defined elsewhere in train_val.py


def train_model(model, criterion, optimizer, scheduler, dataloaders, device, num_epochs=4):
    since = time.time()
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    model = model.cuda()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels, _, _ in dataloaders[phase]:
                inputs = inputs.cuda()
                labels = labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_NORM)
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()
                epoch_loss = running_loss / len(train_tile_ds)
                epoch_acc = running_corrects.double() / len(train_tile_ds)
            else:
                epoch_loss = running_loss / len(val_tile_ds)
                epoch_acc = running_corrects.double() / len(val_tile_ds)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
```

```python
def get_new_base(orig_model):
    # Drop ResNet's final avgpool and fc layers, keeping the convolutional trunk
    return nn.Sequential(*list(orig_model.children())[:-2])


def create_new_block(in_channel, out_channel):
    conv_layer = nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=(1, 1), bias=False)
    batch_layer = nn.BatchNorm2d(out_channel, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    relu = nn.ReLU(inplace=True)
    average_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    return [conv_layer, batch_layer, relu, average_pool]


class CustomResNet(nn.Module):
    def __init__(self, out_c, out_feat):
        super().__init__()
        base = models.resnet34(pretrained=True)

        base_channels_op = base.fc.in_features
        self.base = get_new_base(base)
        self.new_layers = nn.Sequential(*create_new_block(base_channels_op, out_c))
        self.final = nn.Linear(out_c, out_feat)

    def forward(self, x):
        x = self.base(x)
        x = self.new_layers(x)
        x = torch.flatten(x, 1)
        return self.final(x)


def set_grad(m, b):
    # Leave Linear and BatchNorm2d layers untouched; toggle requires_grad
    # on every other layer that has a weight
    if isinstance(m, (nn.Linear, nn.BatchNorm2d)):
        return
    if hasattr(m, 'weight'):
        for p in m.parameters():
            p.requires_grad_(b)
```
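To make the partial freezing concrete, here is a small sketch that applies the same set_grad helper to a hypothetical toy model (a stand-in for CustomResNet, so no pretrained weights need to be downloaded). With b=False only the BatchNorm2d and Linear parameters stay trainable; unfreezing for a second round would be the same apply call with b=True. The helper name trainable is an illustrative assumption, not part of the original code.

```python
from functools import partial

import torch.nn as nn


def set_grad(m, b):
    # Same logic as above: skip Linear/BatchNorm2d, toggle every other weighted layer.
    if isinstance(m, (nn.Linear, nn.BatchNorm2d)):
        return
    if hasattr(m, "weight"):
        for p in m.parameters():
            p.requires_grad_(b)


def trainable(model):
    """Hypothetical helper: names of parameters that currently require gradients."""
    return [name for name, p in model.named_parameters() if p.requires_grad]


# Toy stand-in for CustomResNet: conv -> BN -> ReLU -> pool -> flatten -> linear.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

toy.apply(partial(set_grad, b=False))  # round 1: conv weights frozen
frozen = trainable(toy)
toy.apply(partial(set_grad, b=True))   # round 2: conv weights trainable again
unfrozen = trainable(toy)
print(frozen)
print(unfrozen)
```

Note that apply visits every submodule recursively, so containers such as nn.Sequential (which have no weight attribute) are simply skipped.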


```python
train_loader = DataLoader(train_tile_ds, batch_size=BS, shuffle=True, num_workers=4)
validation_loader = DataLoader(val_tile_ds, batch_size=BS, shuffle=False, num_workers=4)

dataloader_dict = {'train': train_loader, 'val': validation_loader}

mod_resnet = CustomResNet(768, 2)

# First round: freeze everything except Linear and BatchNorm2d layers
mod_resnet.apply(partial(set_grad, b=False))

criterion = nn.CrossEntropyLoss()

optimizer_ft = optim.SGD([
    {'params': mod_resnet.base.parameters(), 'lr': 1e-2},
    {'params': mod_resnet.new_layers.parameters(), 'lr': 1e-3},
    {'params': mod_resnet.final.parameters(), 'lr': 1e-3},
], momentum=0.9)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

trained_model = train_model(mod_resnet, criterion, optimizer_ft, exp_lr_scheduler, dataloader_dict, device, NUM_EPOCHS)

# Second round: trained_model is the nn.DataParallel returned by train_model,
# so the underlying network is reached through .module
optimizer_ft = optim.SGD([
    {'params': trained_model.module.base.parameters(), 'lr': 1e-3},
    {'params': trained_model.module.new_layers.parameters(), 'lr': 1e-3},
    {'params': trained_model.module.final.parameters(), 'lr': 1e-3},
], momentum=0.9)

# Rebuild the scheduler so that it steps the new optimizer rather than the old one
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

trained_model = train_model(trained_model, criterion, optimizer_ft, exp_lr_scheduler, dataloader_dict, device, NUM_EPOCHS)
```
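A side note on the second round: StepLR keeps a reference to the optimizer it was constructed with (its optimizer attribute). A scheduler built for the first optimizer_ft therefore keeps adjusting that first optimizer even after the name optimizer_ft is rebound to a new one. The sketch below shows this on a hypothetical toy nn.Linear model with the same StepLR settings (step_size=7, gamma=0.1).

```python
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler

model = nn.Linear(4, 2)  # toy stand-in for the real network

# First round: optimizer and scheduler built together.
opt1 = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
sched = lr_scheduler.StepLR(opt1, step_size=7, gamma=0.1)

# Second round: a fresh optimizer over the same parameters.
opt2 = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Stepping the old scheduler 7 times decays opt1 only; opt2 is untouched.
for _ in range(7):
    opt1.step()   # no gradients yet, so this is a no-op update
    sched.step()

print(opt1.param_groups[0]["lr"], opt2.param_groups[0]["lr"])

# A scheduler created for the new optimizer is what the second round needs.
sched2 = lr_scheduler.StepLR(opt2, step_size=7, gamma=0.1)
```

opt1.step() is called before sched.step() in the loop so that PyTorch does not warn about the scheduler being stepped before the optimizer.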

Check the .device attribute of the model parameters and of the input data inside the forward method of your model, and try to isolate which tensor causes the device mismatch.
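As a minimal sketch of that check, the hypothetical toy model below (a stand-in for CustomResNet, not the posted code) records the device of its input and of a couple of weights inside forward and prints them if they disagree. In CustomResNet the equivalents would be, for example, self.base[0].weight (the first conv of the ResNet trunk) and self.final.weight. The weights are read as attributes rather than via named_parameters() because, as far as I know, in recent PyTorch versions a DataParallel replica holds its weights as plain tensor attributes rather than registered parameters.

```python
import torch
import torch.nn as nn


class DeviceCheckNet(nn.Module):
    """Toy stand-in for CustomResNet whose forward() reports the devices it sees."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.final = nn.Linear(8, 2)

    def forward(self, x):
        # Read weights as attributes so this also works inside a DataParallel replica.
        seen = {
            "input": x.device,
            "conv.weight": self.conv.weight.device,
            "final.weight": self.final.weight.device,
        }
        if len(set(seen.values())) > 1:
            print(f"device mismatch: {seen}")
        self.last_seen = seen  # only persists when called without DataParallel
        x = torch.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pool -> (N, 8)
        return self.final(x)


model = DeviceCheckNet()
out = model(torch.randn(4, 3, 16, 16))
print(out.shape, model.last_seen)
```

Run under nn.DataParallel, each replica prints its own mismatch line, which tells you which tensor ended up on which GPU.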
Once you have found it, check how this tensor/parameter is created and whether you are manually pushing it to a specific device instead of letting nn.DataParallel handle it.
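One candidate worth ruling out is visible in the traceback itself: data_parallel.py's forward is called again from inside replica 0, which suggests the model ends up wrapped in nn.DataParallel twice. train_model wraps whatever it receives, and the second call passes trained_model, which is already the nn.DataParallel returned by the first call. A minimal sketch of a guard, using two hypothetical helpers wrap_once and unwrap (not PyTorch APIs), could look like this:

```python
import torch
import torch.nn as nn


def wrap_once(model: nn.Module, device_ids=None) -> nn.DataParallel:
    """Hypothetical helper: wrap in DataParallel only if not already wrapped."""
    if isinstance(model, nn.DataParallel):
        return model
    if torch.cuda.is_available():
        model = model.cuda()  # DataParallel expects the module on device_ids[0]
    return nn.DataParallel(model, device_ids=device_ids)


def unwrap(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the underlying module, wrapped or not."""
    return model.module if isinstance(model, nn.DataParallel) else model


base = nn.Linear(4, 2)                # toy stand-in for CustomResNet
first = wrap_once(base)               # first training round
second = wrap_once(first)             # second round passes the wrapped model back in
print(second is first, type(second.module).__name__)
```

Inside train_model, the line that calls nn.DataParallel would then become something like model = wrap_once(model, device_ids=[0, 1, 2, 3]), so that a model returned from the first round is reused as-is instead of being wrapped a second time.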