Transfer learning "from scratch"?

Hey folks,

I am new to transfer learning and am working on a segmentation task using PyTorch/MONAI. I downloaded some pre-trained weights, and now I'd like to load them, freeze some layers, and update the weights of the remaining layers. My goal is to compare whether a fully frozen network (except the final conv layer), or a network with only the first, or the first and second, layer frozen results in better or worse performance.

I am unsure whether the way I do it is correct. Mainly, I am not sure that I have actually loaded the weights and that the model is pretrained, or that the remaining, unfrozen weights are actually learning.

My segmentation is for shoulder, while the pretrained model is for the spleen.

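import fnmatch

import torch
from monai.networks.nets import UNet
from monai.networks.layers import Norm

# set the flag, it will change the initial model object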
self.PRETRAINED = True

def model(self, pretrained):
    # Init model 
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256), 
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(self.device)
    
    # Use pretrained weights
    if pretrained:          
        file = r'C:\Users\BKeoh\Desktop\spleen_ct_segmentation_v0.5.3\spleen_ct_segmentation\models\model.pt'
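        # map_location keeps loading from failing if the checkpoint was saved on another device
        model.load_state_dict(torch.load(file, map_location=self.device))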
  
        # Glob pattern for the parameter names to freeze
        pattern = 'model.0.*'
  
        # Freeze every parameter whose name matches the pattern
        frozen = []
        for name, param in model.named_parameters():
            if fnmatch.fnmatch(name, pattern):
                param.requires_grad = False
                frozen.append(name)
                    
        # Report matched tensors; len(pattern) would only give the string length
        print(f'Using pretrained model with {len(frozen)} frozen parameter tensors')
    
    else:            
        print('Using randomly init. model')
  
    return model

Here, I am assuming that pattern = 'model.0.*' freezes the weights of the first layer of the UNet encoder.

Thanks in advance for any hints!
Cheers.

Your code looks correct assuming you have checked the pattern matching.
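One way to check the pattern matching is to split the parameter names into "would freeze" and "would train" before touching `requires_grad`. The sketch below is only an illustration: `split_by_pattern` is a hypothetical helper, and the names in `example_names` are made up to resemble nested UNet parameter names, so replace them with the output of `[n for n, _ in model.named_parameters()]` from your own model. It uses `fnmatchcase` rather than `fnmatch.fnmatch` because the latter normalizes case on Windows.

```python
from fnmatch import fnmatchcase


def split_by_pattern(names, pattern):
    """Return (matched, unmatched) parameter names for a glob pattern."""
    matched = [n for n in names if fnmatchcase(n, pattern)]
    unmatched = [n for n in names if not fnmatchcase(n, pattern)]
    return matched, unmatched


# Illustrative names only (assumption); use your model's real names instead.
example_names = [
    "model.0.conv.unit0.conv.weight",
    "model.0.conv.unit0.conv.bias",
    "model.1.submodule.0.conv.unit0.conv.weight",
    "model.2.0.conv.weight",
]

frozen, trainable = split_by_pattern(example_names, "model.0.*")
print("would freeze:", frozen)
print("would train: ", trainable)
```

If the "would freeze" list contains anything you did not expect (or is empty), the pattern needs adjusting before you compare the frozen variants.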
The model loading code also looks correct: with its default strict=True, load_state_dict raises an error if the state_dict doesn't match the model architecture.
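If you want to confirm both points yourself (the weights were loaded, and only the unfrozen ones are learning), a check along these lines may help. It is a minimal sketch under assumptions: `make_net` builds a tiny stand-in network rather than the MONAI UNet, `source` plays the role of the checkpoint, and the loss is arbitrary. The same comparisons should carry over to the real model and a real training step.

```python
import torch
from torch import nn

torch.manual_seed(0)


def make_net():
    # Tiny stand-in for the UNet (assumption); only the checking logic matters.
    return nn.Sequential(
        nn.Conv2d(1, 4, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(4, 2, kernel_size=1),
    )


source = make_net()  # plays the role of the pretrained checkpoint
model = make_net()   # freshly initialised model, different random weights
model.load_state_dict(source.state_dict())

# 1) Loaded? Every tensor in the model should equal the checkpoint tensor.
checkpoint = source.state_dict()
loaded_ok = all(torch.equal(v, checkpoint[k]) for k, v in model.state_dict().items())

# 2) Freeze the first layer (its parameter names start with "0." here).
for name, param in model.named_parameters():
    if name.startswith("0."):
        param.requires_grad = False

# Snapshot the weights, run one optimizer step, then compare.
before = {k: p.detach().clone() for k, p in model.named_parameters()}
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1
)
x = torch.randn(2, 1, 8, 8)
loss = model(x).pow(2).mean()  # arbitrary loss, just to get gradients
optimizer.zero_grad()
loss.backward()
optimizer.step()

changed = {
    k: not torch.equal(before[k], p.detach()) for k, p in model.named_parameters()
}
print("loaded_ok:", loaded_ok)
for name, did_change in changed.items():
    print(f"{name}: {'updated' if did_change else 'unchanged'}")
```

Frozen parameters should report as unchanged and the rest as updated. With BatchNorm layers, note that running statistics are buffers rather than parameters, so they can still change in train mode even when the layer's weights are frozen.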
