Accessing intermediate layers of a pretrained network vs using the whole network gives different outputs

So I’ve been working with audio AI for a while and have been experimenting with the PANNs model. My goal was to extend it so that it could be trained on additional classes and predict those on top of the classes it already predicts. The version of PANNs I am using has an attention block at the end, so I simply added a second attention block whose number of output units equals the number of extra classes, and trained only this new block. For prediction, I combine the outputs of the two.

Now here’s the code for my model that finally worked:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
# AttBlock, init_bn, init_layer, interpolate and pad_framewise_output
# come from the PANNs codebase; PANNsDense121Att is defined further below.


class CustomPanns(nn.Module):
    def __init__(self, sample_rate: int, window_size: int, hop_size: int,
                 mel_bins: int, fmin: int, fmax: int, act_classes: int,
                 classes_num: int, apply_aug: bool, wts_path: str, top_db=None):
        super().__init__()

        self.bn0 = nn.BatchNorm2d(mel_bins)           # defined but unused in forward()
        self.fc1 = nn.Linear(1024, 1024, bias=True)   # defined but unused in forward()

        # New attention head for the extra classes
        self.attn = AttBlock(1024, act_classes, activation='sigmoid')

        # Pretrained backbone, restored from the checkpoint and frozen
        self.backbone = PANNsDense121Att(sample_rate, window_size, hop_size,
                                         mel_bins, fmin, fmax, classes_num,
                                         apply_aug, top_db)
        checkpoint = torch.load(wts_path)
        self.backbone.load_state_dict(checkpoint["model"])
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, input_data):
        """Input: (batch_size, channels, data_length)"""
        input_x, mixup_lambda = input_data
        b, c, s = input_x.shape

        # Frozen backbone: returns its penultimate features plus its own outputs
        features, output_dict = self.backbone(input_data)

        # New attention head applied to the same features
        (clipwise_output, norm_att, segmentwise_output) = self.attn(features)

        # Concatenate the backbone's clipwise predictions and the new head's
        # predictions along the class dimension
        combined = torch.cat([output_dict["clipwise_output"],
                              clipwise_output.unsqueeze(1)], dim=2)
        output_dict["combined_output"] = combined.reshape(b, c, combined.shape[2])
        return output_dict
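
Training only the new head then boils down to giving the optimizer just the new attention block’s parameters. A rough sketch (the loss, hyperparameters, checkpoint path and dummy data below are placeholders, not my actual values):

# Rough training sketch: the backbone stays frozen, only self.attn gets gradients.
# Hyperparameters, the checkpoint path and the dummy batch are placeholders.
act_classes = 5   # number of extra classes (example value)
model = CustomPanns(sample_rate=32000, window_size=1024, hop_size=320,
                    mel_bins=64, fmin=50, fmax=14000,
                    act_classes=act_classes, classes_num=527,
                    apply_aug=False, wts_path="panns_checkpoint.pth")

optimizer = torch.optim.Adam(model.attn.parameters(), lr=1e-3)
criterion = nn.BCELoss()   # the attention head already applies a sigmoid

# A single dummy batch stands in for a real data loader
train_loader = [(torch.randn(4, 1, 32000 * 10),
                 torch.randint(0, 2, (4, act_classes)).float())]

model.train()
for input_x, targets in train_loader:        # input_x: (batch, 1, samples)
    output_dict = model((input_x, None))     # no mixup
    # The last act_classes entries of the combined output come from the new head
    preds = output_dict["combined_output"][:, 0, -act_classes:]
    loss = criterion(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()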


The funny thing is that, initially, I accessed the backbone layers a little differently in the forward method: instead of calling self.backbone.forward directly, I essentially re-implemented the backbone’s forward pass step by step, i.e. features = self.backbone.cnn_feature_extractor(...) and so on, operating on those features afterwards. What I noticed is that even though the weights were identical and the backbone was frozen, that version produced a completely different output from the method above. I just want to understand why.
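
Roughly, that earlier version of forward looked like the sketch below (a reconstruction of the idea, not the exact code I had):

    # Sketch of the earlier forward(): driving the backbone's submodules by hand
    # instead of calling self.backbone(input_data). Reconstructed, not exact.
    def forward(self, input_data):
        input_x, mixup_lambda = input_data
        b, c, s = input_x.shape
        input_x = input_x.reshape(b * c, s)

        x, frames_num = self.backbone.preprocess(input_x, mixup_lambda=mixup_lambda)
        x = x.expand(x.shape[0], 3, x.shape[2], x.shape[3])
        features = self.backbone.cnn_feature_extractor(x)
        features = torch.mean(features, dim=3)
        # ... pooling, dropout and fc1 re-applied here, then both
        # self.backbone.att_block and self.attn run on the resulting features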

Code for the backbone model:

class PANNsDense121Att(nn.Module):
    def __init__(self, sample_rate: int, window_size: int, hop_size: int,
                 mel_bins: int, fmin: int, fmax: int, classes_num: int, apply_aug: bool, top_db=None):
        super().__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        self.interpolate_ratio = 32  # Downsampled ratio
        self.apply_aug = apply_aug

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2)

        self.bn0 = nn.BatchNorm2d(mel_bins)

        self.fc1 = nn.Linear(1024, 1024, bias=True)
        self.att_block = AttBlock(1024, classes_num, activation='sigmoid')

        self.densenet_features = models.densenet121(pretrained=False).features

        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def cnn_feature_extractor(self, x):
        x = self.densenet_features(x)
        return x

    def preprocess(self, input_x, mixup_lambda=None):

        x = self.spectrogram_extractor(input_x)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        frames_num = x.shape[2]

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.apply_aug:
            x = self.spec_augmenter(x)

        return x, frames_num

    def forward(self, input_data):
        """Input: (batch_size, channels, data_length)"""
        input_x, mixup_lambda = input_data
        b, c, s = input_x.shape
        input_x = input_x.reshape(b * c, s)
        x, frames_num = self.preprocess(input_x, mixup_lambda=mixup_lambda)
        if mixup_lambda is not None:
            b = (b * c) // 2
            c = 1
        # Output shape (batch_size, channels, time, frequency)
        x = x.expand(x.shape[0], 3, x.shape[2], x.shape[3])  # DenseNet expects 3 input channels
        x = self.cnn_feature_extractor(x)

        # Aggregate in frequency axis
        x = torch.mean(x, dim=3)

        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output,
                                       self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)
        frame_shape = framewise_output.shape
        clip_shape = clipwise_output.shape
        output_dict = {
            'framewise_output': framewise_output.reshape(b, c, frame_shape[1], frame_shape[2]),
            'clipwise_output': clipwise_output.reshape(b, c, clip_shape[1]),
        }

        # Return the penultimate features as well, so a new head can reuse them
        return x, output_dict
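
And for completeness, a minimal end-to-end sketch of running the combined model in inference mode (the hyperparameters and checkpoint path are example values only):

# Inference sketch with example hyperparameters and a placeholder checkpoint path.
model = CustomPanns(sample_rate=32000, window_size=1024, hop_size=320,
                    mel_bins=64, fmin=50, fmax=14000, act_classes=5,
                    classes_num=527, apply_aug=False,
                    wts_path="panns_checkpoint.pth")
model.eval()   # turns off dropout in both the backbone and the new head

waveform = torch.randn(2, 1, 32000 * 10)   # (batch, channels, samples): 10 s at 32 kHz
with torch.no_grad():
    output_dict = model((waveform, None))  # mixup_lambda=None

print(output_dict["clipwise_output"].shape)    # (2, 1, 527)  -> original classes
print(output_dict["combined_output"].shape)    # (2, 1, 532)  -> original + 5 new classes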