Function AddBackward0 returned an invalid gradient at index 1 - expected type torch.FloatTensor but got torch.cuda.FloatTensor

My Model:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import geffnet  # (generic) EfficientNet backbones, provides efficientnet_b3
# Dynamic_relu_b is defined further down; l2_norm is a small helper (not shown)

class myNet(nn.Module):
    def __init__(self):
        super(myNet,self).__init__()
        backbone=geffnet.efficientnet_b3(pretrained=True)
        act1=Dynamic_relu_b(40)
        act2=Dynamic_relu_b(1536)
      
        self.backbone=torch.nn.Sequential(
            backbone.conv_stem,
            backbone.bn1,
            act1,
            backbone.blocks,
            backbone.conv_head,
            backbone.bn2,
            act2,
            backbone.global_pool
        )
        
        self.global_avgpool=torch.nn.AdaptiveAvgPool2d(1)
        self.global_bn=nn.BatchNorm1d(1536)
        self.global_bn.bias.requires_grad=False
        self.local_conv=nn.Conv2d(1536,512,1)
        self.local_bn=nn.BatchNorm2d(512)
        self.local_bn.bias.requires_grad=False
        self.fc=nn.Linear(1536,20)
        nn.init.kaiming_normal_(self.fc.weight,mode='fan_out')
        nn.init.constant_(self.fc.bias,0)

    def forward(self,x):
        x=self.backbone(x)

        global_feat=self.global_avgpool(x)
        global_feat=global_feat.view(global_feat.shape[0],-1)
        global_feat=F.dropout(global_feat,p=0.2)
        global_feat=self.global_bn(global_feat)
        global_feat=l2_norm(global_feat)
    
        local_feat=torch.mean(x,-1,keepdim=True)
        local_feat=self.local_bn(self.local_conv(local_feat))
        local_feat=local_feat.squeeze(-1).permute(0,2,1)
        local_feat=l2_norm(local_feat,axis=-1)

        out=self.fc(global_feat)*16
        return global_feat,local_feat,out
def one_hot_smooth_label(x,num_class,smooth=0.1):
    num=x.shape[0]
    labels=torch.zeros((num,num_class))
    for i in range(num):
        labels[i][x[i]]=1
    labels=(1-(num_class-1)/num_class*smooth)*labels+smooth/num_class
    return labels
images=torch.rand((4,3,300,300))
images=images.cuda().float()
labels=torch.from_numpy(np.array([1,0,0,1]))
model=myNet()
model=model.cuda()
global_feat,local_feat,cls_score=model(images)
# global_feat=global_feat.to('cpu')
# local_feat=local_feat.to('cpu')
cls_score=cls_score.to('cpu')
labels=one_hot_smooth_label(labels,20)
criterion=nn.BCEWithLogitsLoss()

loss=criterion(cls_score,labels)
loss.backward()

This is the error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-39-951a8a08a40d> in <module>
     20 
     21 loss=criterion(cls_score,labels)
---> 22 loss.backward()

/opt/conda/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    193                 products. Defaults to ``False``.
    194         """
--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    196 
    197     def register_hook(self, hook):

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
---> 99         allow_unreachable=True)  # allow_unreachable flag
    100 
    101 

RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected type torch.FloatTensor but got torch.cuda.FloatTensor

After many experiments, I found the problem:

# if we delete act1 and act2 from the model, the error does not happen

class myNet(nn.Module):
    def __init__(self):
        super(myNet,self).__init__()
        backbone=geffnet.efficientnet_b3(pretrained=True)
        act1=Dynamic_relu_b(40)
        act2=Dynamic_relu_b(1536)
      
        self.backbone=torch.nn.Sequential(
            backbone.conv_stem,
            backbone.bn1,
            #act1,
            backbone.blocks,
            backbone.conv_head,
            backbone.bn2,
            #act2,
            backbone.global_pool
        )
        
        self.global_avgpool=torch.nn.AdaptiveAvgPool2d(1)
        self.global_bn=nn.BatchNorm1d(1536)
        self.global_bn.bias.requires_grad=False
        self.local_conv=nn.Conv2d(1536,512,1)
        self.local_bn=nn.BatchNorm2d(512)
        self.local_bn.bias.requires_grad=False
        self.fc=nn.Linear(1536,20)
        nn.init.kaiming_normal_(self.fc.weight,mode='fan_out')
        nn.init.constant_(self.fc.bias,0)

    def forward(self,x):
        x=self.backbone(x)

        global_feat=self.global_avgpool(x)
        global_feat=global_feat.view(global_feat.shape[0],-1)
        global_feat=F.dropout(global_feat,p=0.2)
        global_feat=self.global_bn(global_feat)
        global_feat=l2_norm(global_feat)
    
        local_feat=torch.mean(x,-1,keepdim=True)
        local_feat=self.local_bn(self.local_conv(local_feat))
        local_feat=local_feat.squeeze(-1).permute(0,2,1)
        local_feat=l2_norm(local_feat,axis=-1)

        out=self.fc(global_feat)*16
        return global_feat,local_feat,out

This is the Dynamic_relu code:


class Residual(nn.Module):
    def __init__(self, in_channel, R=8, k=2):
        super(Residual, self).__init__()
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.R = R
        self.k = k
        out_channel = int(in_channel / R)
        self.fc1 = nn.Linear(in_channel, out_channel)
        fc_list = []
        for i in range(k):
            fc_list.append(nn.Linear(out_channel, 2 * in_channel))
        self.fc2 = nn.ModuleList(fc_list)

    def forward(self, x):
        x = self.avg(x)
        x = torch.squeeze(x)
        x = self.fc1(x)
        x = self.relu(x)
        result_list = []
        for i in range(self.k):
            result = self.fc2[i](x)
            result = 2 * torch.sigmoid(result) - 1
            result_list.append(result)
        return result_list


class Dynamic_relu_b(nn.Module):
    def __init__(self, inchannel, R=8, k=2):
        super(Dynamic_relu_b, self).__init__()
        self.lambda_alpha = 1
        self.lambda_beta = 0.5
        self.R = R
        self.k = k
        self.init_alpha = torch.zeros(self.k)
        self.init_beta = torch.zeros(self.k)
        self.init_alpha[0] = 1
        self.init_beta[0] = 1
        for i in range(1, k):
            self.init_alpha[i] = 0
            self.init_beta[i] = 0

        self.residual = Residual(inchannel)

    def forward(self, input):
        delta = self.residual(input)
        in_channel = input.shape[1]
        bs = input.shape[0]
        alpha = torch.zeros((self.k, bs, in_channel))
        beta = torch.zeros((self.k, bs, in_channel))
        for i in range(self.k):
            for j, c in enumerate(range(0, in_channel * 2, 2)):
                alpha[i, :, j] = delta[i][:, c]
                beta[i, :, j] = delta[i][:, c + 1]
        alpha1 = alpha[0]
        beta1 = beta[0]
        max_result = self.dynamic_function(alpha1, beta1, input, 0)
        for i in range(1, self.k):
            alphai = alpha[i]
            betai = beta[i]
            result = self.dynamic_function(alphai, betai, input, i)
            max_result = torch.max(max_result, result)
        return max_result

    def dynamic_function(self, alpha, beta, x, k):
        init_alpha = self.init_alpha[k]
        init_beta = self.init_beta[k]
        alpha = init_alpha + self.lambda_alpha * alpha
        beta = init_beta + self.lambda_beta * beta
        bs = x.shape[0]
        channel = x.shape[1]
        results = torch.zeros_like(x)
        for i in range(bs):
            for c in range(channel):
                results[i, c, :, :] = x[i, c] * alpha[i, c] + beta[i, c]
        return results

How can I solve this problem?

Hi Beilei!

The short answer is that you are mixing gpu and cpu tensors together.
(But I do not know why commenting out act1 and act2 makes your
error (appear to) go away.)

Your images live in the gpu.

Your labels live in the cpu.

Your model lives in the gpu.

Gpu images passed through gpu model. Okay.

cls_score now lives in the cpu (you moved it there with cls_score.to('cpu')).

labels still live in the cpu because one_hot_smooth_label() returns
a cpu tensor.

Cpu cls_score and cpu labels passed to BCEWithLogitsLoss.
Okay. (And loss lives in the cpu.)

But you are trying to backpropagate a cpu loss through a gpu model.
So somewhere inside, your tensors don’t match.

(Again, I don’t know why removing act1 and act2 makes the error appear to go away, but I doubt that is the real fix.)

Probably the simplest thing would be to leave cls_score in the gpu,
and move labels to the gpu. Thus:

# cls_score=cls_score.to('cpu')
labels=one_hot_smooth_label(labels,20)
labels = labels.cuda()
criterion=nn.BCEWithLogitsLoss()
loss=criterion(cls_score,labels)
loss.backward()

Now cls_score, labels, and loss should all be consistently in the
gpu, so you should be able to backpropagate your gpu loss through
your gpu model.
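
If you want to double-check, a quick sanity test (just a diagnostic, not part of
the fix) is to print where everything lives before calling backward():

# everything that feeds the loss should report cuda:0
print(cls_score.device)                   # expect cuda:0 (no longer moved to the cpu)
print(labels.device)                      # expect cuda:0 after labels.cuda()
print(next(model.parameters()).device)    # expect cuda:0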

Good luck.

K. Frank

Thank you very much for your kind answer, which is indeed reasonable. Unfortunately, it doesn’t seem to have solved the problem:

class myNet(nn.Module):
    def __init__(self):
        super(myNet,self).__init__()
        self.act1=Dynamic_relu_b(64)
        self.conv1=nn.Conv2d(3,64,3)
        self.pool=nn.AdaptiveAvgPool2d(1)
        self.fc=nn.Linear(64,20)
    def forward(self,x):
        x=self.conv1(x)
        x=self.act1(x)
        x=self.pool(x)
        x=x.view(x.shape[0],-1)
        x=self.fc(x)
        return x

model=myNet()
model=model.cuda()
output=model(images)
labels=one_hot_smooth_label(labels,20)
labels = labels.cuda()
criterion=nn.BCEWithLogitsLoss()

loss=criterion(output,labels)
loss.backward()

RuntimeError                              Traceback (most recent call last)
<ipython-input-42-1268777e87e6> in <module>()
     21 
     22 loss=criterion(output,labels)
---> 23 loss.backward()

1 frames
/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 

RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected type TensorOptions(dtype=float, device=cpu, layout=Strided, requires_grad=false) but got TensorOptions(dtype=float, device=cuda:0, layout=Strided, requires_grad=false) (validate_outputs at /pytorch/torch/csrc/autograd/engine.cpp:484)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x46 (0x7fcf7711b536 in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x2d84224 (0x7fcfb1bad224 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&) + 0x548 (0x7fcfb1baed58 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&, bool) + 0x3d2 (0x7fcfb1bb0ce2 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::Engine::thread_init(int) + 0x39 (0x7fcfb1ba9359 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::python::PythonEngine::thread_init(int) + 0x38 (0x7fcfbe2e8378 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0xbd6df (0x7fcfe23416df in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #7: <unknown function> + 0x76db (0x7fcfe34236db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #8: clone + 0x3f (0x7fcfe375c88f in /lib/x86_64-linux-gnu/libc.so.6)

Again, once act1 is removed, the error disappears!

So the error must lie somewhere in the Dynamic_relu code.

Hi Beilei!

Well, at least you got a different error.

Note, you’re using a different version of python than in your first post
(and perhaps a different version of pytorch, as well). This makes it
harder to compare your two runs.

Good point!

From your Dynamic_relu_b code:

    alpha = torch.zeros((self.k, bs, in_channel))
    beta = torch.zeros((self.k, bs, in_channel))

Those torch.zeros() calls are also cpu tensors, even after model.cuda(), so you
will still have a cpu-gpu mismatch.

I don’t know what the standard way is to write an activation so that it
“automatically” moves to the gpu when you call model.cuda(), but
I don’t think you get this behavior for free.

Things you could try:

First, run your whole model on the cpu, just to make sure it works. So
leave Dynamic_relu_b as it is, and don’t move your model, images,
or labels to the gpu.

If that works, you might go back to the gpu. The easiest way might
be to write a cuda version of Dynamic_relu_b, where you move its
internal tensors to the gpu.

And if that works, try to figure out how to make Dynamic_relu_b work
“automatically” on both the cpu and gpu. After all, the various built-in
pytorch Modules know how to do this, so there must be some standard
way. (But I don’t know what it is.)
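
For what it’s worth, one pattern that at least makes a module’s extra tensors
follow the model’s device is to register fixed tensors as buffers (so that
model.cuda() / model.to(device) moves them) and to create per-forward temporaries
with device=x.device. The sketch below is only an illustration of that mechanism,
not a rewrite of Dynamic_relu_b; the DeviceAwareActivation module and the trivial
scaling it performs are made up for the example.

import torch
import torch.nn as nn

class DeviceAwareActivation(nn.Module):   # hypothetical module, for illustration only
    def __init__(self, k=2):
        super().__init__()
        init_alpha = torch.zeros(k)
        init_alpha[0] = 1.0
        # a buffer is not a learnable parameter, but it is moved by .cuda()/.to()
        # and saved in the state_dict
        self.register_buffer('init_alpha', init_alpha)

    def forward(self, x):
        # per-forward temporaries created on the same device (and dtype) as the
        # input, so the module works on cpu and gpu alike
        alpha = torch.zeros(x.shape[0], x.shape[1], device=x.device, dtype=x.dtype)
        alpha = alpha + self.init_alpha[0]    # the buffer already lives on x's device
        return x * alpha.view(x.shape[0], x.shape[1], 1, 1)

With that pattern, nn.Sequential(nn.Conv2d(3, 8, 3), DeviceAwareActivation()).cuda()
runs entirely on the gpu, and the same code runs unchanged on the cpu.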

Good luck.

K. Frank

I asked the question on stackoverflow and got a very good answer, which I hope will help you too

answer

Your intuition is right, and you’re close to the right answer