RuntimeError: Function CatBackward returned an invalid gradient at index 1 - expected device cuda:1 but got cuda:0

I am trying to modify a public codebase so that it trains with data parallelism across two GPUs. I have seen this topic, but it is about model parallelism rather than DataParallel.

The code for creating and training the encoder-decoder network is as follows (Please look at the lines marked by #MY CHANGE):

from __future__ import print_function
import argparse
import random
import torch
import torch.optim as optim
import sys
import os
sys.path.append(os.path.join(os.path.dirname(sys.path[0]), 'auxiliary'))
from datasetSMPL2 import *
from model_sample import *
from utils import *
from ply import *
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
parser.add_argument('--nepoch', type=int, default=100, help='number of epochs to train for')

opt = parser.parse_args()
# ========================================================== #
import os
sys.path.append(os.path.join(os.path.dirname(sys.path[0]), 'extension'))

blue = lambda x: '\033[94m' + x + '\033[0m'

opt.manualSeed = random.randint(1, 10000)  # fix seed
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
L2curve_train_smpl = []
L2curve_val_smlp = []

# meters to record stats on learning
train_loss_L2_smpl = AverageValueMeter()
val_loss_L2_smpl = AverageValueMeter()
tmp_val_loss = AverageValueMeter()

# ===================CREATE DATASET============================#
dataset = SMPL(train=True, regular = True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
# ===================CREATE network==============================#
network = AE_AtlasNet_Humans()
gpu_num = torch.cuda.device_count()  # MY CHANGE
if torch.cuda.device_count() > 1:  # MY CHANGE
    print("Let's use", gpu_num, "GPUs!")  # MY CHANGE
    network = nn.DataParallel(network, list(range(gpu_num))).cuda() # MY CHANGE

network.apply(weights_init)  # initialization of the weight
# ===================CREATE optimizer=============================#
lrate = 0.001  # learning rate
optimizer = optim.Adam(network.parameters(), lr=lrate)
# =============start of the learning loop =============================== #
for epoch in range(0, opt.nepoch):
    if epoch==80:
        lrate = lrate/10.0  # learning rate scheduled decay
        optimizer = optim.Adam(network.parameters(), lr=lrate)
    if epoch==90:
        lrate = lrate/10.0  # learning rate scheduled decay
        optimizer = optim.Adam(network.parameters(), lr=lrate)

    # TRAIN MODE
    train_loss_L2_smpl.reset()
    network.train()
    for i, data in enumerate(dataloader, 0):
        optimizer.zero_grad()
        points, idx,_ = data
        points = points.transpose(2, 1).contiguous()
        points = points.cuda()
        pointsReconstructed = network(points, idx)  # MY CHANGE: replacing "network.forward_idx" with "network"
        # target = points.transpose(2, 1).contiguous().cuda(non_blocking=True)  # tried to use DataParallel with loss; does not fix the error!
        # criterion = nn.DataParallel(nn.MSELoss())
        # criterion.cuda()
        # loss_net = criterion(pointsReconstructed, target)
        # loss_net.backward(torch.Tensor([1, 1]).cuda())
        loss_net = torch.mean((pointsReconstructed - points.transpose(2, 1).contiguous()) ** 2)
        loss_net.backward()  # RUNTIMEERROR
        train_loss_L2_smpl.update(loss_net.item())
        optimizer.step()  # gradient update

The code for building the model is as follows (Please look at the lines marked by #MY CHANGE):

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import trimesh

class STN3d(nn.Module):
    def __init__(self, num_points = 2500):
        super(STN3d, self).__init__()
        self.num_points = num_points
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x,_ = torch.max(x, 2)
        x = x.view(-1, 1024)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class PointNetfeat(nn.Module):
    def __init__(self, num_points = 2500, global_feat = True, trans = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d(num_points = num_points)
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)

        self.bn1 = torch.nn.BatchNorm1d(64)
        self.bn2 = torch.nn.BatchNorm1d(128)
        self.bn3 = torch.nn.BatchNorm1d(1024)
        self.trans = trans

        self.num_points = num_points
        self.global_feat = global_feat

    def forward(self, x):
        batchsize = x.size()[0]
        if self.trans:
            trans = self.stn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans)
            x = x.transpose(2,1)
        x = F.relu(self.bn1(self.conv1(x)))
        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x,_ = torch.max(x, 2)
        x = x.view(-1, 1024)
        if self.trans:
            if self.global_feat:
                return x, trans
            else:
                x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
                return torch.cat([x, pointfeat], 1), trans
        else:
            return x


class PointGenCon(nn.Module):
    def __init__(self, bottleneck_size = 2500):
        super(PointGenCon, self).__init__()
        self.bottleneck_size = bottleneck_size

        self.conv1 = torch.nn.Conv1d(bottleneck_size, bottleneck_size, 1)
        self.conv2 = torch.nn.Conv1d(bottleneck_size, bottleneck_size//2, 1)
        self.conv3 = torch.nn.Conv1d(bottleneck_size//2, bottleneck_size//4, 1)
        self.conv4 = torch.nn.Conv1d(bottleneck_size//4, 3, 1)

        self.th = nn.Tanh()
        self.bn1 = torch.nn.BatchNorm1d(bottleneck_size)
        self.bn2 = torch.nn.BatchNorm1d(bottleneck_size//2)
        self.bn3 = torch.nn.BatchNorm1d(bottleneck_size//4)

    def forward(self, x):
        batchsize = x.size()[0]
        # print(x.size())
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = 2*self.th(self.conv4(x))
        return x


class AE_AtlasNet_Humans(nn.Module):
    def __init__(self, num_points = 6890, bottleneck_size = 1024, nb_primitives = 1):
        super(AE_AtlasNet_Humans, self).__init__()
        self.num_points = num_points
        self.bottleneck_size = bottleneck_size
        self.nb_primitives = nb_primitives
        self.encoder = nn.Sequential(
        PointNetfeat(num_points, global_feat=True, trans = False),
        nn.Linear(1024, self.bottleneck_size),
        nn.BatchNorm1d(self.bottleneck_size),
        nn.ReLU()
        )

        self.decoder = nn.ModuleList([PointGenCon(bottleneck_size = 3 +self.bottleneck_size) for i in range(0,self.nb_primitives)])

        import sys
        import os
        mesh = trimesh.load(os.path.join(os.path.dirname(sys.path[0]), 'data', 'template', 'male_template.ply'), process=False)
        self.mesh = mesh
        mesh_HR = trimesh.load(os.path.join(os.path.dirname(sys.path[0]), 'data', 'template', 'male_template_dense.ply'), process=False)
        self.mesh_HR = mesh_HR
        point_set = mesh.vertices

        bbox = np.array([[np.max(point_set[:,0]), np.max(point_set[:,1]), np.max(point_set[:,2])], [np.min(point_set[:,0]), np.min(point_set[:,1]), np.min(point_set[:,2])]])
        tranlation = (bbox[0] + bbox[1]) / 2
        point_set = point_set - tranlation

        point_set_HR = mesh_HR.vertices
        bbox = np.array([[np.max(point_set_HR[:,0]), np.max(point_set_HR[:,1]), np.max(point_set_HR[:,2])], [np.min(point_set_HR[:,0]), np.min(point_set_HR[:,1]), np.min(point_set_HR[:,2])]])
        tranlation = (bbox[0] + bbox[1]) / 2
        point_set_HR = point_set_HR - tranlation

        self.vertex = torch.from_numpy(point_set).cuda().float()
        self.vertex_HR = torch.from_numpy(point_set_HR).cuda().float()
        self.num_vertex = self.vertex.size(0)
        self.num_vertex_HR = self.vertex_HR.size(0)

    def forward(self, x, idx): # MY CHANGE: replacing "forward_idx" with "forward"
        x = self.encoder(x)
        outs = []
        for i in range(0,self.nb_primitives):
            idx = idx.view(-1)
            idx = idx.cpu().data.numpy().astype(np.int64)
            rand_grid = self.vertex[idx,:]
            rand_grid = rand_grid.view(x.size(0), -1, 3).transpose(1,2).contiguous()
            rand_grid = Variable(rand_grid)
            y = x.unsqueeze(2).expand(x.size(0),x.size(1), rand_grid.size(2)).contiguous()
            y = torch.cat( (rand_grid, y), 1).contiguous()
            if x.is_cuda: # MY CHANGE: adding ".cuda()"
                y = y.cuda() # MY CHANGE: adding ".cuda()"
            outs.append(self.decoder[i](y))
        # res = torch.cat(outs,2).contiguous().transpose(2,1).contiguous().to("cuda:0") # tried to solve the RuntimeError; does not fix the error!
        # return res
        return torch.cat(outs,2).contiguous().transpose(2,1).contiguous()

Running loss_net.backward() raises the RuntimeError shown in the title. I suspect it is related to the change I made in the forward function of AE_AtlasNet_Humans (adding ".cuda()" to y), which I made to fix an error I got in the forward pass after switching to DataParallel.
I tried several things to fix the error in the backward pass, but none of them solved it. You can find them in the commented-out lines.
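On the commented-out loss attempts: nn.DataParallel gathers the replica outputs onto its output_device (device_ids[0], i.e. cuda:0, by default), so the loss can be computed outside the wrapper with the target moved to the prediction's device; wrapping the criterion in DataParallel should not be needed. A minimal sketch, assuming pred has shape (B, N, 3) and points has shape (B, 3, N) as in the training loop; reconstruction_loss is a hypothetical helper name and the tensors are CPU stand-ins:

```python
import torch
import torch.nn.functional as F


def reconstruction_loss(pred: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    """Mean squared error between pred (B, N, 3) and points (B, 3, N).

    The target is moved to pred's device, which for a DataParallel output is
    the gather device (device_ids[0] by default).
    """
    target = points.transpose(2, 1).contiguous().to(pred.device)
    # Same value as torch.mean((pred - target) ** 2) in the original loop.
    return F.mse_loss(pred, target)


# CPU stand-ins for pointsReconstructed and points.
pred = torch.randn(2, 6, 3, requires_grad=True)
points = torch.randn(2, 3, 6)
loss = reconstruction_loss(pred, points)
loss.backward()  # scalar loss, so no gradient argument is required
```

Note that this only addresses where the loss is computed; the error in the title comes from CatBackward inside the model, so the loss itself is unlikely to be the culprit.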

Have you tried sending all variables to CUDA in the same place in the code? I see that you are sending different variables to devices both outside the training loop and inside it.

I cannot verify this, but it is possible that .cuda() (which targets the current CUDA device, cuda:0 by default) is sending tensors to different devices in different parts of the code, especially since one call is made inside the DataParallel module and the other outside it.
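A minimal sketch of choosing the device in one place, assuming a single setup helper fits the script; setup_model is a hypothetical name and nn.Linear stands in for AE_AtlasNet_Humans. It also avoids a quirk of the original script, where .cuda() is only called when more than one GPU is present:

```python
from typing import Tuple

import torch
import torch.nn as nn


def setup_model(model: nn.Module) -> Tuple[nn.Module, torch.device]:
    """Pick the device once, move the model there, and wrap it if several GPUs exist."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # DataParallel expects parameters and buffers on device_ids[0] (cuda:0 here).
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        # Replicates the module onto every visible GPU on each forward() call.
        model = nn.DataParallel(model)
    return model, device


# Usage with a stand-in model; inputs are moved with the same `device` object.
net, device = setup_model(nn.Linear(3, 2))
batch = torch.randn(4, 3).to(device)
out = net(batch)
print(out.shape)  # torch.Size([4, 2])
```

With DataParallel the output is gathered back onto cuda:0, so out.device matches device in both the single-device and multi-GPU cases.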

It is also a good idea to use a debugger to check which device the tensors used to compute the loss (pointsReconstructed and points.transpose(2, 1).contiguous()) are on. For example, .contiguous() may return a copy of the tensor, but on which device? A debugger should also help you check the gradients (or whether they exist in the first place).
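If a full debugger is inconvenient, a couple of print-based helpers can answer the same questions. A minimal sketch; report_devices and watch_grad_device are hypothetical names, and the CPU tensors stand in for pointsReconstructed and points:

```python
import torch


def report_devices(**tensors: torch.Tensor) -> set:
    """Print the device of each named tensor and return the set of devices seen."""
    devices = set()
    for name, t in tensors.items():
        print(f"{name}: device={t.device}, requires_grad={t.requires_grad}")
        devices.add(t.device)
    return devices


def watch_grad_device(name: str, t: torch.Tensor) -> None:
    """Register a hook that prints where the gradient of `t` lands during backward()."""
    # Returning None from the hook leaves the gradient unchanged.
    t.register_hook(lambda g: print(f"grad of {name}: device={g.device}"))


# Stand-ins for pointsReconstructed and the transposed points.
pred = torch.randn(2, 5, 3, requires_grad=True)
target = torch.randn(2, 5, 3)
devices = report_devices(pred=pred, target=target)
assert len(devices) == 1, f"tensors are spread over {devices}"
watch_grad_device("pred", pred)
loss = torch.mean((pred - target) ** 2)
loss.backward()
```

Inside a model's forward, the same report_devices call on intermediate tensors (e.g. rand_grid and x) shows which one sits on an unexpected GPU.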


Thank you, good suggestion. Having debugged the code and watched the tensors, I noticed that rand_grid is on a different GPU from x, which affects y as well. Fixing that solved the problem.

I replaced self.vertex = torch.from_numpy(point_set).cuda().float() with self.vertex = torch.from_numpy(point_set).float(), and rand_grid = self.vertex[idx,:] with rand_grid = self.vertex[idx,:].cuda(). I also deleted the following two lines, which came just before outs.append(self.decoder[i](y)) (that call stays):

if x.is_cuda: # MY CHANGE: adding ".cuda()"
    y = y.cuda() # MY CHANGE: adding ".cuda()"
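The fix works because nn.DataParallel runs each replica's forward under that replica's current-device context, so a bare .cuda() inside forward targets the replica's GPU. A different approach, not the one used here, is to register the template vertices as a buffer: buffers move with .to()/.cuda() and are copied to every replica, so no device call is needed in forward. A minimal sketch, assuming only the template-grid part of AE_AtlasNet_Humans; TemplateGrid is a hypothetical stand-in class:

```python
import numpy as np
import torch
import torch.nn as nn


class TemplateGrid(nn.Module):
    """Hypothetical stand-in for the template-vertex part of AE_AtlasNet_Humans."""

    def __init__(self, point_set: np.ndarray):
        super().__init__()
        # Unlike a plain tensor attribute, a buffer is moved by .to()/.cuda()
        # and copied to each GPU replica by nn.DataParallel.
        self.register_buffer("vertex", torch.from_numpy(point_set).float())

    def forward(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        # idx arrives already split per replica; index on the buffer's device.
        idx = idx.view(-1).to(self.vertex.device, dtype=torch.long)
        rand_grid = self.vertex[idx, :]
        rand_grid = rand_grid.view(x.size(0), -1, 3).transpose(1, 2).contiguous()
        # Expand the latent code along the point dimension, as in the original forward.
        y = x.unsqueeze(2).expand(x.size(0), x.size(1), rand_grid.size(2))
        return torch.cat((rand_grid, y), 1).contiguous()


vertices = np.random.rand(10, 3)       # 10 template vertices
grid = TemplateGrid(vertices)
x = torch.randn(2, 4)                  # batch of 2 latent codes of size 4
idx = torch.randint(0, 10, (2, 5))     # 5 template points per sample
y = grid(x, idx)
print(y.shape)  # torch.Size([2, 7, 5])
```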

The reason is that DataParallel splits each input tensor along the batch dimension into one chunk per GPU (two here) and copies each chunk to a different GPU, where a replica of the model processes it. So any tensor used in forward that is not one of its input arguments (such as self.vertex) must be moved to the same device as the input arguments.
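The split can be reproduced on the CPU to see what each replica receives. A minimal sketch, assuming DataParallel's scatter divides tensors along dim 0 roughly the way torch.chunk does; the shapes are hypothetical and mirror the training loop (idx is a tensor argument too, so it is split alongside points):

```python
import torch

# Hypothetical batch: 4 samples, 3 coordinates, 6 points each.
points = torch.randn(4, 3, 6)
idx = torch.arange(4 * 6).view(4, 6)

num_gpus = 2
# Split each tensor argument along the batch dimension, one chunk per device.
point_chunks = torch.chunk(points, num_gpus, dim=0)
idx_chunks = torch.chunk(idx, num_gpus, dim=0)

for rank, (p, i) in enumerate(zip(point_chunks, idx_chunks)):
    # Under DataParallel, chunk `rank` would live on cuda:<rank>; anything the
    # replica builds from module attributes must end up on that same device.
    print(f"replica {rank}: points {tuple(p.shape)}, idx {tuple(i.shape)}")
```

Because each replica sees only its own slice of idx, rand_grid.view(x.size(0), -1, 3) in forward still lines up with the per-replica batch size.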
