RuntimeError: Function CatBackward returned an invalid gradient at index 1 - expected device 1 but got 0

I am testing using Pytorch with multiple GPUs. I am not using DataParallel, but I want to use Model Parallelism.
My Model design is two input branches (each in separate GPUs). I have done an example with MNIST for reproducibility.
But at training, I get the exception mentioned below. Any help would be appreciated.

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Hyperparameters
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001

DATA_PATH = '/data/'
MODEL_STORE_PATH = '/models/'

# transforms to apply to the data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# MNIST dataset
train_dataset = datasets.MNIST(root=DATA_PATH, train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(root=DATA_PATH, train=False, transform=trans)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

gpu1 = torch.device("cuda:0")
gpu2 = torch.device("cuda:1")

class DistConvNet(nn.Module):
    def __init__(self):
        super(DistConvNet, self).__init__()
        
        # gpu1
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer1.to(gpu1)
        self.layer2.to(gpu1)
        
        # gpu2
        self.layer3 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer4 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3.to(gpu2)
        self.layer4.to(gpu2)
        
        self.drop_out = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64 * 2, 1000)
        self.fc2 = nn.Linear(1000, 10)
        
        self.drop_out.to(gpu1)
        self.fc1.to(gpu1)
        self.fc2.to(gpu1)
        
    def forward(self, x1, x2):
        out1 = self.layer1(x1)
        out1 = self.layer2(out1)
        
        out2 = self.layer3(x2)
        out2 = self.layer4(out2)
        
        out1 = out1.reshape(out1.size(0), -1)
        out2 = out2.reshape(out2.size(0), -1)
        out2.to(gpu1)
        
        
        out = torch.cat((out1, out2), 1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

model_dist = DistConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_dist.parameters(), lr=learning_rate)


total_step = len(train_loader)
loss_list = []
acc_list = []
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Run the forward pass
        images_gpu1, images_gpu2, labels = images.to(gpu1), images.to(gpu2), labels.to(gpu1)
        
        outputs = model_dist(images_gpu1, images_gpu2)
        loss = criterion(outputs, labels)
        loss_list.append(loss.item())

        # Backprop and perform Adam optimisation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track the accuracy
        total = labels.size(0)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        acc_list.append(correct / total)

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item(),
                          (correct / total) * 100))

Exception:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-27984b0eb824> in <module>
     13         # Backprop and perform Adam optimisation
     14         optimizer.zero_grad()
---> 15         loss.backward()
     16         optimizer.step()
     17 

/opt/conda/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    100                 products. Defaults to ``False``.
    101         """
--> 102         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    103 
    104     def register_hook(self, hook):

/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     88     Variable._execution_engine.run_backward(
     89         tensors, grad_tensors, retain_graph, create_graph,
---> 90         allow_unreachable=True)  # allow_unreachable flag
     91 
     92 

RuntimeError: Function CatBackward returned an invalid gradient at index 1 - expected device 1 but got 0

Thanks.

Your code looks generally alright besides this line of code:

out2.to(gpu1)

While nn.Modules are transferred inplace, you have have to assign tensors back:

out2 = out2.to(gpu1)

Could you fix this line and see it it’s working?

Thanks. It is fixed with this.

Hi,
I am getting the same error while using Model parallelism . Below is my code.

import torch
from torch import nn
import torch.nn.functional as F


class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=5,
        wf=6,
        padding=False,
        batch_norm=False,
        up_mode='upconv',
    ):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        Using the default arguments will yield the exact version used
        in the original paper

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): depth of the network
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList().cuda(0)
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList().cuda(1)
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1).cuda(2)

    def forward(self, x):
        blocks = []
        x=x.cuda(0)
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)
        x=x.cuda(1)
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])
        x=x.cuda(2)
        return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*block).cuda(0)

    def forward(self, x):
        x=x.cuda(0)
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2).cuda(1)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            ).cuda(1)

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        x=x.cuda(1)
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out
optimizer.zero_grad()
 x, y = batch
 y_pred = model(x.reshape(bs,1,image_size,image_size).float())
 label=y.reshape(bs,1,image_size,image_size).cuda(2).float()
 loss_fn = nn.BCEWithLogitsLoss()
 dice = f_score(y_pred, label)
 iou1=iou(y_pred, label)
 loss_fn(y_pred, label).backward()
 optimizer.step()

Error:

<ipython-input-5-fc64eb051240> in process_function(engine, batch)
      8     dice = f_score(y_pred, label)
      9     iou1=iou(y_pred, label)
---> 10     loss_fn(y_pred, label).backward()
     11     optimizer.step()
     12     #print("train", loss.item(), dice.item(), iou1.item())

~/anaconda3/envs/fm/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    105                 products. Defaults to ``False``.
    106         """
--> 107         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    108 
    109     def register_hook(self, hook):

~/anaconda3/envs/fm/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     91     Variable._execution_engine.run_backward(
     92         tensors, grad_tensors, retain_graph, create_graph,
---> 93         allow_unreachable=True)  # allow_unreachable flag
     94 
     95 

RuntimeError: Function CatBackward returned an invalid gradient at index 1 - expected device cuda:0 but got cuda:1```