Question on loading pretrained weights into another trained model

I work with videos (input shape `[batch, channels, frames, height, width]`). I designed a video classification model that uses a pretrained ResNet50 to extract features from each frame separately (2D convolutions) and then stacks the extracted features of all frames. I trained the model on a video classification dataset, and now I want to reuse it for another task. Since I loaded the pretrained ResNet50 weights into my model during training, my question is: when I load my trained model's checkpoint for the new task, do I also need to load the ResNet50 weights in my script?

import os
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn.functional as F
from collections import OrderedDict
# Number of output classes of the video classifier (channel count of the
# final 1x1x1 classification convolution in Spatial_Encoder).
num_class = 25
class BasicConv3d(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU building block.

    The convolution is created without a bias term; the affine BatchNorm that
    follows it has its own learnable shift. BatchNorm uses ``eps=1e-3`` and
    ``momentum=0.001``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: Conv3d kernel size.
        stride: Conv3d stride.
        padding: Conv3d padding (default 0).
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        # Attribute names (conv / bn / relu) form the state_dict keys; keep them.
        self.conv = nn.Conv3d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))
    
    

class Interpolate_1(nn.Module):
    """Resize the input to a fixed ``size`` with ``torch.nn.functional.interpolate``.

    Args:
        size: target output size forwarded to ``interpolate``.
        mode: interpolation mode forwarded to ``interpolate`` (e.g. 'bilinear',
            'nearest'). This argument used to be ignored ('bilinear' was always
            used); it is now honoured.
    """

    # ``interpolate`` only accepts ``align_corners`` for these modes and raises
    # a ValueError if it is passed with any other mode (e.g. 'nearest').
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, size, mode):
        super(Interpolate_1, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode

    def forward(self, x):
        # Keep align_corners=False for interpolating modes (the previous
        # behaviour for 'bilinear'); omit it for modes that reject it.
        align_corners = False if self.mode in self._ALIGN_CORNERS_MODES else None
        x = self.interp(x, size=self.size, mode=self.mode, align_corners=align_corners)
        return x


# Clip length in frames (presumably the training clip length). Not read by
# Spatial_Encoder.forward, which passes x.size(2) as ``len_temporal`` to
# forward_features; the shape comments there assume 16 frames.
len_temporal = 16

# Public API of this module.
__all__ = ['Spatial_Encoder']

class Spatial_Encoder(nn.Module):
    """Video classifier built on a per-frame ResNet50 backbone.

    Each frame of a ``[batch, channels, frames, height, width]`` clip is run
    through ResNet50 up to ``layer4``. The per-frame maps are stacked along a
    temporal axis, reduced by four stride-2 temporal convolutions, spatially
    averaged, and classified by a 1x1x1 convolution.

    Args:
        pretrained: if True (default), initialise the backbone with torchvision's
            ``ResNet50_Weights.IMAGENET1K_V2``. Pass False when a trained
            checkpoint of this model will be loaded with ``load_state_dict``
            right after construction: the checkpoint already holds the backbone
            parameters, so downloading ImageNet weights is unnecessary.

    Note:
        Submodule attribute names (``fc_1``, ``featureExtractor``,
        ``conv_t_1``..``conv_t_4``, ``fc``) are the keys of the saved
        state_dict. Renaming or removing any of them breaks loading existing
        checkpoints with ``strict=True``.
    """

    def __init__(self, pretrained=True):
        super(Spatial_Encoder, self).__init__()

        # Classification head used by forward().
        self.fc_1 = nn.Sequential(nn.Conv3d(2048, num_class, kernel_size=1, stride=1, bias=True))

        # Honour ``pretrained`` (it used to be ignored, so ImageNet weights were
        # always downloaded even when a checkpoint would overwrite them).
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.featureExtractor = resnet50(weights=weights)

        # Kept for backward compatibility: forward_features() stores the latest
        # frame's (detached) layer4 output here, as the former forward hook did.
        self.activation = {}

        # Temporal reduction: each block maps T frames to ceil(T / 2).
        self.conv_t_1 = BasicConv3d(2048, 2048, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        self.conv_t_2 = BasicConv3d(2048, 2048, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        self.conv_t_3 = BasicConv3d(2048, 2048, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        self.conv_t_4 = BasicConv3d(2048, 2048, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))

        # Not used by forward(); kept so existing checkpoints (which contain
        # fc.0.weight / fc.0.bias) still load with strict=True.
        self.fc = nn.Sequential(nn.Conv3d(2048, num_class, kernel_size=1, stride=1, bias=True),)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

        # Same device placement as before: these submodules go to the GPU when
        # one is available (self.fc was never moved).
        if torch.cuda.is_available():
            for module in (self.fc_1, self.featureExtractor,
                           self.conv_t_1, self.conv_t_2, self.conv_t_3, self.conv_t_4):
                module.cuda()

    def forward_features(self, x, len_temporal):
        """Encode the first ``len_temporal`` frames with the ResNet50 backbone.

        Args:
            x: clip tensor ``[batch, channels, frames, height, width]``.
            len_temporal: number of leading frames to encode.

        Returns:
            ``[batch, 2048, len_temporal, h, w]`` stack of per-frame layer4 maps.
        """
        backbone = self.featureExtractor
        features = []
        for t in range(len_temporal):
            frame = x[:, :, t]  # [batch, channels, height, width]
            # Same operations torchvision's ResNet applies up to layer4. Calling
            # them directly skips the discarded avgpool/fc computation and avoids
            # a hook writing to shared state (unsafe with nn.DataParallel).
            feat = backbone.conv1(frame)
            feat = backbone.bn1(feat)
            feat = backbone.relu(feat)
            feat = backbone.maxpool(feat)
            feat = backbone.layer1(feat)
            feat = backbone.layer2(feat)
            feat = backbone.layer3(feat)
            feat = backbone.layer4(feat)
            # detach() matches the former hook: no gradient flows into the
            # backbone, so its weights are not trained. Remove it to fine-tune.
            feat = feat.detach()
            self.activation['layer4'] = feat
            features.append(feat)
        return torch.stack(features, dim=2)

    def forward(self, x):
        """Classify a clip.

        Args:
            x: ``[batch, channels, frames, height, width]`` on the model's device.

        Returns:
            ``[batch, num_class]`` logits.
        """
        # Shape comments below are for a 16-frame clip.
        features = self.forward_features(x, len_temporal=x.size(2))  # [batch, 2048, 16, 4, 4]

        # No explicit .cuda(): the features already live on the backbone's
        # device, so this also works on CPU-only machines and non-default GPUs.
        y = self.conv_t_1(features)                                   # [batch, 2048, 8, 4, 4]
        y = self.conv_t_2(y)                                          # [batch, 2048, 4, 4, 4]
        y = self.conv_t_3(y)                                          # [batch, 2048, 2, 4, 4]
        y = self.conv_t_4(y)                                          # [batch, 2048, 1, 4, 4]

        # Spatial average, keeping the temporal axis.
        y = F.avg_pool3d(y, (1, y.size(3), y.size(4)), stride=1)     # [batch, 2048, 1, 1, 1]

        y = self.fc_1(y)                                              # [batch, num_class, 1, 1, 1]

        y = y.view(y.size(0), y.size(1), y.size(2))                   # [batch, num_class, 1]

        # Average over any remaining time steps (more than one if frames > 16).
        logits = torch.mean(y, 2)                                     # [batch, num_class]
        return logits

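If you’ve already trained your model and stored the trained state_dict, you can load it directly. There is no need to first initialize the model with the ImageNet-pretrained ResNet50 weights and then overwrite the parameters afterwards. The stored state_dict already contains the trained ResNet50 parameters, and loading it restores them.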

@ptrblck, thank you for your answer. I’ve already trained my model. As I understand your answer, when I load the weights of my model, I should not load the ResNet weights (i.e., ResNet50_Weights). However, I didn’t understand how to overwrite the parameters after that. Could you please explain?

Moreover, I stored the weights of my model with the extension “.pyth”. On the other hand, during training I loaded the ResNet weights in my code as follows:

from torchvision.models import resnet50, ResNet50_Weights
self.featureExtractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

I don’t know which file extension the resnet50 weights use (e.g., .pth, .pt, etc.). Will it cause problems if the extension of my model’s weights (.pyth) doesn’t match the extension of the ResNet weights?

The file extension doesn’t matter; you can use any name you want.
You can load your state_dict via model.load_state_dict(torch.load(PATH_TO_STATE_DICT)).

Will `model.load_state_dict(torch.load(PATH_TO_STATE_DICT))` overwrite the parameters, or is further action needed?

Yes, `load_state_dict` will load and overwrite every parameter and buffer stored in the state_dict whose key matches the attribute name of a parameter or buffer in the model.

Thank you so much for the point!