Using strict=False for fine-tuning

I have a model, and after training it on a dataset I use the following method (LoadStateDictCustom), which I implemented in my model class, to load the stored weights for fine-tuning:

from collections import OrderedDict

import torch
import torch.nn as nn


class YourModel(nn.Module):
    def __init__(self, n_classes, in_chann):
        super(YourModel, self).__init__()

    def LoadStateDictCustom(self, StateDicPath):
        StatDic = torch.load(StateDicPath, map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
        StatDic2 = OrderedDict()
        for key, value in StatDic.items():
            # rename a single key, copy all others unchanged
            if key == 'OLD key name that must be changed  ':
                StatDic2['NEW key name for second time'] = value
            else:
                StatDic2[key] = value
        self.load_state_dict(StatDic2)

    def forward(self, x, trg=None):
        pass

However, before starting fine-tuning I need to change some of the layers, so I need to load the weights with strict=False. Unfortunately, I have no idea where I should apply strict=False: in the training script or in the model? Any ideas?

strict=False can be passed to the load_state_dict call, but be careful to check for incompatible keys afterwards, as no error will be raised if keys are missing or unexpected.
Here is a small example:

import torch.nn as nn
from torchvision import models

# setup
model = models.resnet18()
sd = model.state_dict()

# make sure we can load the state_dict
model.load_state_dict(sd)
# <All keys matched successfully>

# manipulate the model by adding a new but unused layer
model.new_layer = nn.Linear(10, 10)

# this will now fail
model.load_state_dict(sd)
# RuntimeError: Error(s) in loading state_dict for ResNet:
# 	Missing key(s) in state_dict: "new_layer.weight", "new_layer.bias". 

# use strict=False as a workaround, as an alternative to adding the missing keys to the state_dict
model.load_state_dict(sd, strict=False)
# _IncompatibleKeys(missing_keys=['new_layer.weight', 'new_layer.bias'], unexpected_keys=[])

# create a new model that is completely incompatible with the state_dict
model = nn.Linear(10, 10)
# this is not loading anything!
model.load_state_dict(sd, strict=False)
# _IncompatibleKeys(missing_keys=['weight', 'bias'], unexpected_keys=['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', ..., 'fc.weight', 'fc.bias'])

Note that the last code snippet does not load anything from sd, which is indicated in the returned _IncompatibleKeys. However, if you don’t explicitly check it, you might easily introduce errors into your code by assuming the state_dict was (partially) loaded.
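To make that check explicit, here is a minimal sketch of a wrapper that loads with strict=False and then raises if anything other than the expected keys is missing. The helper name load_partial and its allowed_missing parameter are made up for this illustration and are not part of PyTorch; the small nn.Sequential model only stands in for a real network.

```python
import torch.nn as nn


def load_partial(model, state_dict, allowed_missing=()):
    """Load with strict=False, then fail loudly on unexpected mismatches.

    `load_partial` and `allowed_missing` are hypothetical names used only here.
    """
    # load_state_dict returns a named tuple with missing_keys and unexpected_keys
    result = model.load_state_dict(state_dict, strict=False)
    not_allowed = [k for k in result.missing_keys if k not in allowed_missing]
    if not_allowed or result.unexpected_keys:
        # note: any matching keys have already been copied into the model here
        raise RuntimeError(
            f"missing keys: {not_allowed}, unexpected keys: {result.unexpected_keys}"
        )
    return result


# small stand-in model instead of resnet18
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
sd = model.state_dict()

# add a new layer that the stored state_dict knows nothing about
model.add_module("new_layer", nn.Linear(2, 2))

result = load_partial(
    model, sd, allowed_missing=("new_layer.weight", "new_layer.bias")
)
print(result.missing_keys)
```

Calling load_partial(nn.Linear(10, 10), sd) would raise here, since none of the keys match. Also keep in mind that strict=False only relaxes the key check: a key that exists in both the model and the state_dict but with a different shape still raises a RuntimeError.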

Dear @ptrblck, many thanks for the answer. How can I use LoadStateDictCustom(), as mentioned in my question, together with strict=False?

Pass this argument to load_state_dict as seen in my code snippet.

Unfortunately, when I change the ‘LoadStateDictCustom’ call, I get an error saying that it does not accept ‘strict’. How can I modify it?

I don’t know what exactly you are replacing, so could you post a minimal and executable code snippet showing the error you are running into, please?

@ptrblck For fine-tuning I use the following code in my training script:

    weight_name='weight.pt'
    file_weight=os.path.join('output','model_weights',subfolder,'name1',weight_name)
    
    optim_name='adam.pt'
    file_optimizer=os.path.join('output','model_weights',subfolder,'name1',optim_name)
    
    model = modelName(pretrained=encoder_pretrained)
    
    model.requires_grad_(False)
    model.last_conv.requires_grad_(True)
   
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

    dev = device 
    model.to(device)

    

    # loading file weight (fine-tuning)
    model.LoadStateDictCustom(file_weight, strict=False)   #***************
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=2e-7)
    
    
    # loading optimizer (fine-tuning)
    optimizer.load_state_dict(torch.load(file_optimizer))
    torch.backends.cudnn.benchmark = True

    #torch.cuda.set_device('cuda:2')
    criterion = KLDLoss1vs1('cuda')    
    model.train()
    

Unfortunately, I get the following error for the line I marked with stars:

model.LoadStateDictCustom() got an unexpected keyword argument 'strict'

You would need to add the strict argument to your function definition, i.e. LoadStateDictCustom(self, StateDicPath, strict), and then pass it on to the actual load_state_dict call inside that method.
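A minimal sketch of what that could look like is below. The layers in __init__, the default value strict=True, the return of the load result, and the temporary weight.pt file are assumptions added for illustration; the key-renaming branch from your original method would go inside the loop.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn


class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        # hypothetical layers, only here to make the example runnable
        self.encoder = nn.Linear(8, 8)
        self.last_conv = nn.Linear(8, 2)

    def LoadStateDictCustom(self, StateDicPath, strict=True):
        StatDic = torch.load(StateDicPath, map_location="cpu")
        StatDic2 = OrderedDict()
        for key, value in StatDic.items():
            # the key renaming from the original method would go here
            StatDic2[key] = value
        # forward strict and return the result so the caller can inspect it
        return self.load_state_dict(StatDic2, strict=strict)

    def forward(self, x, trg=None):
        return self.last_conv(self.encoder(x))


# store the weights of the "pretrained" model in a temporary file
model = YourModel()
path = os.path.join(tempfile.mkdtemp(), "weight.pt")
torch.save(model.state_dict(), path)

# change the model before fine-tuning, then load non-strictly
model.new_head = nn.Linear(2, 2)
result = model.LoadStateDictCustom(path, strict=False)
print(result.missing_keys)  # ['new_head.weight', 'new_head.bias']
```

With this signature, the call model.LoadStateDictCustom(file_weight, strict=False) from your training script should no longer raise the unexpected keyword argument error, and existing calls without strict keep the strict behavior.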