How to change a pretrained model from 3 RGB channels to 4 channels without causing a non-leaf tensor error?

Hello guys, I have been trying to change the first conv layer of the pretrained PyTorch DenseNet from 3 input channels to 4 while keeping the pretrained weights for the original RGB channels. I wrote the following code, but creating the optimizer throws this error: "ValueError: can't optimize a non-leaf Tensor".

    import torch
    import torch.nn as nn
    import torchvision.models as models

    backbone = models.__dict__['densenet169'](pretrained=True)

    # Copy the pretrained 3-channel weights of the first conv layer.
    weight1 = backbone.features.conv0.weight.data.clone()
    new_first_layer = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    # Write the RGB weights into the first 3 input channels without tracking gradients.
    with torch.no_grad():
        new_first_layer.weight[:, :3] = weight1

    backbone.features.conv0 = new_first_layer
    optimizer = torch.optim.SGD(backbone.parameters(), 0.001,
                                weight_decay=0.1)  # Changing this optimizer from SGD to ADAM

I have also tried removing the `with torch.no_grad():` block, but the issue remains:

    ValueError                                Traceback (most recent call last)
    <ipython-input-343-5fc87352da04> in <module>()
         11 backbone.features.conv0 = new_first_layer
         12 optimizer = torch.optim.SGD(res.parameters(), 0.001,
    ---> 13                                  weight_decay=0.1)  # Changing this optimizer from SGD to ADAM
    
    ~/anaconda3/envs/detectron2/lib/python3.6/site-packages/torch/optim/sgd.py in __init__(self, params, lr, momentum, dampening, weight_decay, nesterov)
         66         if nesterov and (momentum <= 0 or dampening != 0):
         67             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    ---> 68         super(SGD, self).__init__(params, defaults)
         69 
         70     def __setstate__(self, state):
    
    ~/anaconda3/envs/detectron2/lib/python3.6/site-packages/torch/optim/optimizer.py in __init__(self, params, defaults)
         50 
         51         for param_group in param_groups:
    ---> 52             self.add_param_group(param_group)
         53 
         54     def __getstate__(self):
    
    ~/anaconda3/envs/detectron2/lib/python3.6/site-packages/torch/optim/optimizer.py in add_param_group(self, param_group)
        231                                 "but one of the params is " + torch.typename(param))
        232             if not param.is_leaf:
    --> 233                 raise ValueError("can't optimize a non-leaf Tensor")
        234 
        235         for name, default in self.defaults.items():
    
    ValueError: can't optimize a non-leaf Tensor

My PyTorch version is: 1.7.0.

Could you guys please help? Thanks a lot!

Regards.

Hey, I copied your code and it ran fine.

What?! Really? Could you please let me know which PyTorch version you are using? Thanks!
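One way to see where this error comes from, sketched with a randomly initialized stand-in for `conv0` instead of the actual pretrained DenseNet (an assumption, so the snippet runs without a download): parameters registered on a module are leaf tensors, while anything produced by an autograd operation such as `torch.cat` is not, and only the latter makes the optimizer raise. Note also that the traceback calls `res.parameters()` while the posted code builds `backbone`, so the error may have come from a different object left over in the notebook.

```python
import torch
import torch.nn as nn

# Stand-in for the pretrained 3-channel first conv (random weights, no download).
conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Parameters registered on a module are leaf tensors.
leaf_flags = [p.is_leaf for p in conv.parameters()]

# Concatenating an extra input channel is an autograd operation, so the
# result carries a grad_fn and is NOT a leaf.
extra = torch.zeros(64, 1, 7, 7)
widened = torch.cat([conv.weight, extra], dim=1)

try:
    torch.optim.SGD([widened], lr=0.001)
    error_message = None
except ValueError as err:
    error_message = str(err)

# Wrapping a detached copy in nn.Parameter yields a fresh leaf tensor
# that the optimizer accepts.
widened_param = nn.Parameter(widened.detach().clone())
optimizer = torch.optim.SGD([widened_param], lr=0.001)

print(leaf_flags, widened.is_leaf, widened_param.is_leaf)
print(error_message)
```

If the optimizer complains on a full model, checking `p.is_leaf` for each entry of `named_parameters()` on the exact object passed to the optimizer should point at the offending tensor.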

Here is a working solution for ResNet:


Guys, I have resolved this. Here is the code I used:

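    import torch
    import torch.nn as nn
    import torchvision.models as models
    from torch.autograd import Variable  # deprecated wrapper, kept from the original fix

    backbone = models.__dict__['densenet169'](pretrained=True)
    weight1 = backbone.features.conv0.weight.clone()
    new_first_layer = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).requires_grad_()
    # Write the pretrained RGB weights into the first 3 input channels through .data,
    # so the new layer's weight stays a leaf tensor.
    new_first_layer.weight[:, :3, :, :].data[...] = Variable(weight1, requires_grad=True)
    backbone.features.conv0 = new_first_layer
    # Pass the parameters of the modified model (backbone) to the optimizer.
    optimizer = torch.optim.SGD(backbone.parameters(), 0.001,
                                weight_decay=0.1)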
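For comparison, here is a minimal sketch of the same idea without `Variable` or `.data`, using an in-place copy under `torch.no_grad()`. The helper name `expand_first_conv` is hypothetical, the stand-in layer uses random weights in place of the pretrained `conv0`, and filling the extra channel with the mean of the RGB weights is an assumed initialization, not something the thread prescribes.

```python
import torch
import torch.nn as nn


def expand_first_conv(old_conv: nn.Conv2d, in_channels: int = 4) -> nn.Conv2d:
    """Return a conv like `old_conv` that accepts `in_channels` inputs.

    The original input-channel weights are copied over; extra channels are
    filled with the mean of the original channels (an assumed initialization).
    """
    old_in = old_conv.in_channels
    new_conv = nn.Conv2d(
        in_channels,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=old_conv.bias is not None,
    )
    # In-place copies under no_grad keep new_conv.weight a leaf parameter.
    with torch.no_grad():
        new_conv.weight[:, :old_in].copy_(old_conv.weight)
        new_conv.weight[:, old_in:].copy_(old_conv.weight.mean(dim=1, keepdim=True))
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    return new_conv


# Stand-in for backbone.features.conv0 (random weights, no download).
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = expand_first_conv(old_conv, in_channels=4)

optimizer = torch.optim.SGD([new_conv.weight], lr=0.001, weight_decay=0.1)
output = new_conv(torch.randn(1, 4, 224, 224))
print(new_conv.weight.is_leaf, tuple(output.shape))
```

With the real model, this would presumably be used as `backbone.features.conv0 = expand_first_conv(backbone.features.conv0)` before building the optimizer from `backbone.parameters()`.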