When calling loss.backward() I get the error: Function MulBackward0 returned an invalid gradient at index 1 - expected device cuda:0 but got cuda:1

I ran the code using DataParallel and got the error.
My demo code is as follows:

import math
import torch
import torch.nn as nn


class EnergyConstrainedADMM_P_imagenet(nn.Module):
    def __init__(self, net_model, width_ub, blocks, scale=None):
        super(EnergyConstrainedADMM_P_imagenet, self).__init__()

        if scale is not None:
            width_ub = [width_ub[0]] + [int(i * scale) for i in width_ub[1:-1]] + [width_ub[-1]]
            print(width_ub)

        self.net = net_model
        self.s = nn.Parameter(torch.FloatTensor(width_ub))
        self.t1 = nn.Parameter(torch.rand(1, blocks[0]))
        self.t2 = nn.Parameter(torch.rand(1, blocks[1]))
        self.t3 = nn.Parameter(torch.rand(1, blocks[2]))
        self.t4 = nn.Parameter(torch.rand(1, blocks[3]))
        self.t = [self.t1, self.t2, self.t3, self.t4]   # collect the per-stage 1-D parameters in a list

        self.tau = 5.0

    # Code taken from TAS SoftSelect.py
    def select2withP(self, logits, tau):
        if tau <= 0:
            probs = nn.functional.softmax(logits, dim=1)
        else:
            while True:  # a trick to avoid the gumbels bug
                gumbels = -torch.empty_like(logits).exponential_().log()
                new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
                probs = nn.functional.softmax(new_logits, dim=1)
                if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) \
                        and (not torch.isnan(probs).any()):
                    break
        return probs

    def set_tau(self, tau_max, tau_min, epoch_ratio):
        self.tau = tau_min + (tau_max - tau_min) * (1 + math.cos(math.pi * epoch_ratio)) / 2

    def forward(self, x):
        probs = []
        for i in range(len(self.t)):
            probs.append(self.select2withP(self.t[i], self.tau))

        features = []
        for i, layer in enumerate(self.net.layers):
            layer_i = layer(x)
            features.append(layer_i)

            # print('i:{}, layer_i:{}'.format(i, layer_i.shape))
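            # At the last layer of each stage, fuse that stage's block outputs into a
            # weighted sum, using the Gumbel-softmax probs of the corresponding stage.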
            if i in self.net.info:
                choices = self.net.info[i]
                tensors = features[choices[0]:choices[-2] + 1]
                probs_i = probs[choices[-1]].view(-1)
                x = sum([tensor * w for tensor, w in zip(tensors, probs_i)])

            elif i == len(self.net.layers) - 2:
                x = layer_i.view(layer_i.size(0), -1)  # avgpool
            else:
                x = layer_i

        return x


'''
bottleneck block, which is used for ResNet-50,101
'''


class Bottleneck(nn.Module):
    def __init__(self, inplanes, cfg, stride=1):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, cfg[0], kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv3 = nn.Conv2d(cfg[1], cfg[2], kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm2d(cfg[0])
        self.bn2 = nn.BatchNorm2d(cfg[1])
        self.bn3 = nn.BatchNorm2d(cfg[2])

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if stride != 1 or inplanes != cfg[2]:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, cfg[2], kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(cfg[2]),
            )
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet_Bottleneck(nn.Module):

    def __init__(self, cfg, nums, Class_num=1000):
        self.cfg = cfg
        self.inplanes = 64
        super(ResNet_Bottleneck, self).__init__()

        self.layers = nn.ModuleList([nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )])

        strides = [1, 2, 2, 2]
        self.info = {}
        for i in range(4):  # 4 stage
            filters = cfg[sum(nums[:i]) * 3: sum(nums[:i + 1]) * 3]
            self.layers.append(Bottleneck(self.inplanes, filters[0:3], strides[i]))
            self.inplanes = filters[2]

            for num in range(1, nums[i]):
                self.layers.append(Bottleneck(self.inplanes, filters[num * 3:(num + 1) * 3]))
                self.inplanes = filters[num * 3 + 2]

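            # Map the index of the last block in stage i to the indices of all blocks
            # in that stage, plus the stage index i (used to pick probs in the wrapper).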
            self.info[sum(nums[:i + 1])] = list(range(sum(nums[:i]) + 1, sum(nums[:i + 1]) + 1)) + [i]

        self.layers.append(nn.AvgPool2d(7, stride=1))
        self.layers.append(nn.Linear(cfg[-1], Class_num))

        # print(self.info)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            layer_i = layer(x)
            if i == len(self.layers) - 2:
                x = layer_i.view(layer_i.size(0), -1)  # avgpool
            else:
                x = layer_i

        return x

    def get_bn_weights(self):
        res = []
        for m in self.modules():
            if isinstance(m, Bottleneck):
                res += [m.bn1.weight, m.bn2.weight, m.bn3.weight]
        return res

    def bn_update(self, args):
        for m in self.modules():
            if isinstance(m, Bottleneck):
                m.bn1.weight.grad.data.add_(args.bn_sparsity * torch.sign(m.bn1.weight.data))  # L1
                m.bn2.weight.grad.data.add_(args.bn_sparsity * torch.sign(m.bn2.weight.data))
                m.bn3.weight.grad.data.add_(args.bn_sparsity * torch.sign(m.bn3.weight.data))


def resnet50(cfg=None):
    if cfg is None:
        cfg = [[64, 64, 256] * 3, [128, 128, 512] * 4, [256, 256, 1024] * 6, [512, 512, 2048] * 3]
        cfg = [item for sub_list in cfg for item in sub_list]

    nums = [3, 4, 6, 3]
    model = ResNet_Bottleneck(cfg, nums)

    return model


# get model
model = resnet50().cuda()
width_ub = [3, 64] + resnet50().cfg + [1000]
blocks = [3, 4, 6, 3]
primal_model = EnergyConstrainedADMM_P_imagenet(model, width_ub, blocks).cuda()
primal_model = torch.nn.DataParallel(primal_model)

# test
img = torch.rand(4,3,224,224).cuda()
out = model(img)
out = primal_model(img)
target = torch.tensor([1,1,1,1]).cuda()

criterion = nn.CrossEntropyLoss()
loss = criterion(out, target)
loss.backward()

The error is as follows:

Traceback (most recent call last):
  File "demo.py", line 210, in <module>
    loss.backward()
  File "/home/lth/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/lth/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function MulBackward0 returned an invalid gradient at index 1 - expected device cuda:0 but got cuda:2

Hi,

Could you try to enable anomaly mode to get a better idea of where the error occurs please?
You can do this by adding torch.autograd.set_detect_anomaly(True) at the beginning of your script.
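
For example, at the very top of your script (before the model is built and the forward pass runs):

import torch

torch.autograd.set_detect_anomaly(True)  # records forward stack traces so backward errors point at the offending op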

Thanks! I added torch.autograd.set_detect_anomaly(True) and then got the following error:


/pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
  File "/home/lth/anaconda3/lib/python3.7/threading.py", line 890, in _bootstrap
    self._bootstrap_inner()
  File "/home/lth/anaconda3/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/lth/anaconda3/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/lth/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/lth/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "demo.py", line 58, in forward
    x = sum([tensor * w for tensor, w in zip(tensors, probs_i)])
  File "demo.py", line 58, in <listcomp>
    x = sum([tensor * w for tensor, w in zip(tensors, probs_i)])

Traceback (most recent call last):
  File "demo.py", line 210, in <module>
    loss.backward()
  File "/home/lth/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/lth/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function MulBackward0 returned an invalid gradient at index 1 - expected device cuda:0 but got cuda:3

The error happens in the forward of EnergyConstrainedADMM_P_imagenet, but it only happens when using torch.nn.DataParallel.

Looking forward to your reply, thanks!

The line x = sum([tensor * w for tensor, w in zip(tensors, probs_i)]) seems to be to blame here. But if the elements are on different GPUs, it should fail during the forward, never during the backward.

After investigation, a minimal repro is:

import torch

torch.autograd.set_detect_anomaly(True)

a = torch.rand([], requires_grad=True, device="cuda:0")
b = torch.rand(10, requires_grad=True, device="cuda:1")

c = a * b
c.sum().backward() # Fails with the same error
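
For comparison, the multiplication and its backward go through fine once both operands are explicitly put on the same device:

c = a.to(b.device) * b
c.sum().backward()  # works: both operands of the mul live on cuda:1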

I’ll open a GitHub issue for that as soon as GitHub is back online :smiley:

To fix your code, you should make sure that every element is on the right device. In particular, here it seems like probs is not on the right GPU after the module is wrapped in DataParallel. A simple fix would be probs_i = probs_i.to(tensors[0].device).
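
Applied to your forward above, the patched branch would look roughly like this (only the device move is new):

            if i in self.net.info:
                choices = self.net.info[i]
                tensors = features[choices[0]:choices[-2] + 1]
                probs_i = probs[choices[-1]].view(-1)
                probs_i = probs_i.to(tensors[0].device)  # move the stage weights onto the same GPU as the features
                x = sum([tensor * w for tensor, w in zip(tensors, probs_i)])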

Great!!! Thanks! :smiley:

For reference, here it is: https://github.com/pytorch/pytorch/issues/33870

I have got a similar issue - the same error at the same function, but I am trying to run my code on CPU.

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B", device_map='auto', offload_folder="offload_folder")

Then I wrap it with PEFT and try to use it for fine-tuning:

    src_tokens = tokenizer(orig, padding=True, truncation=True, max_length=512, return_tensors='pt').to('cpu')
    tokenizer.src_lang = 'eng_Hebr'
    tgt_tokens = tokenizer(target, padding=True, truncation=True, max_length=512, return_tensors='pt').to('cpu')
    tgt_tokens.input_ids[tgt_tokens.input_ids == tokenizer.pad_token_id] = -100


    loss = model(**src_tokens, labels=tgt_tokens.input_ids).loss
    loss.backward()

I also tried without .to('cpu'), but it didn't help. I always get the error about the expected 'meta' device.
Any advice? It is the first time I'm trying to use PyTorch, so sorry if this is basic.

Could you post the error message you are receiving, as I'm unsure if your error is related to this topic?