Transfer learning with different inputs

I have a similar problem: I have a pretrained CNN14 model which I want to use a base model for an audio detection task in a multitask framework. I need to expand the in_channel values of the first block layer (4, 7) for both subnetworks.

Below is the  Pretrained CNN14 model:

from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_utils import do_mixup, interpolate, pad_framewise_output
 

def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
 
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception('Incorrect argument!')
        
        return x
class Cnn14(nn.Module):
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 
        fmax, classes_num):
        
        super(Cnn14, self).__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, pad_mode=pad_mode, 
            freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 
            n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 
            freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
        
        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)
    
    def change_first_layer(m,self):
        for name, child in m.named_children():
              if isinstance(child,ConvBlock ): #nn.Conv2d
                kwargs = {
                'out_channels': child.out_channels,
                'kernel_size': child.kernel_size,
                'stride': child.stride,
                'padding': child.padding,
                'bias': False if child.bias == None else True
        }
              m._modules[self.conv_block1] = ConvBlock(4, **kwargs)
              return True
        else:
            if(change_first_layer(child)):
             return True
        return False
 
    def forward(self, input, mixup_lambda=None):
        """
        Input: (batch_size, data_length)"""

        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)
        
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        
        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)
        
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))
        
        output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

        return output_dict


class Transfer_Cnn14(nn.Module):
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 
        fmax, classes_num, freeze_base):
        """Classifier for a new task using pretrained Cnn14 as a sub module.
        """
    
        super(Transfer_Cnn14, self).__init__()
        audioset_classes_num = 13 #527
        
        self.base = Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin, 
            fmax, audioset_classes_num)

        # Transfer to another task layer
       #self.fc_transfer = nn.Linear(2048, classes_num, bias=True)

        if freeze_base:
            # Freeze AudioSet pretrained layers
            for param in self.base.parameters():
                param.requires_grad = False

        self.init_weights()

    def init_weights(self):
        init_layer(self.fc_transfer)

    def load_from_pretrain(self, pretrained_checkpoint_path):
        checkpoint = torch.load(pretrained_checkpoint_path)
        self.base.load_state_dict(checkpoint['model'])

    def forward(self, input, mixup_lambda=None):
        """Input: (batch_size, data_length)
        """
        output_dict = self.base(input, mixup_lambda)
        embedding = output_dict['embedding']

        clipwise_output =  torch.log_softmax(self.fc_transfer(embedding), dim=-1)
        output_dict['clipwise_output'] = clipwise_output
 
        return output_dict`

I tired to transfer the base model(Pretrained) into the task:
class EINV2(nn.Module):
def init(self, cfg, dataset):
super().init()
self.pe_enable = False # Ture | False
self.in_channels_sed = 4
self.in_channels_doa = 7
if cfg[‘data’][‘audio_feature’] == ‘logmel&intensity’:
self.f_bins = cfg[‘data’][‘n_mels’]
# self.in_channels_doa = 7
# self.in_channels_sed = 4

    self.sed = nn.Sequential(
        Transfer_Cnn14(in_channels = 4,  classes_num = 14, freeze_base = False),
          nn.AvgPool2d(kernel_size=(2, 2))

    )
    # self.sed = (Transfer_Cnn14(4,  classes_num = 14, freeze_base = False),
    #       nn.AvgPool2d(kernel_size=(2, 2))
    # )
    self.doa= nn.Sequential(
         Transfer_Cnn14(7,  classes_num = 14, freeze_base = False),
            nn.AvgPool2d(kernel_size=(2, 2))
    )
    if self.pe_enable:
       
     self.pe = PositionalEncoding(pos_len=100, d_model=512, pe_type='t', dropout=0.0)
    self.sed_trans_track1 = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
    self.sed_trans_track2 = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
    self.doa_trans_track1 = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
    self.doa_trans_track2 = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)

    self.fc_sed_track1 = nn.Linear(512, 14, bias=True)
    self.fc_sed_track2 = nn.Linear(512, 14, bias=True)
    self.fc_doa_track1 = nn.Linear(512, 3, bias=True)
    self.fc_doa_track2 = nn.Linear(512, 3, bias=True)
    self.final_act_sed = nn.Sequential() # nn.Sigmoid()
    self.final_act_doa = nn.Tanh()

    self.init_weight()
    if  freeze_base:
        # Freeze AudioSet pretrained layers
        for param in self.base.parameters():
            param.requires_grad = False

        self.init_weights()

def init_weights(self):
    init_layer(self) #.fc_transfer

def load_from_pretrain(self, pretrained_checkpoint_path):
    checkpoint = torch.load(pretrained_checkpoint_path)
    self.base.load_state_dict(checkpoint['model'])

def forward(self, input, mixup_lambda=None):
    """Input: (batch_size, data_length)
    """
    output_dict = self.base(input, mixup_lambda)
    embedding = output_dict['embedding']
def init_weight(self):

    init_layer(self.fc_sed_track1)
    init_layer(self.fc_sed_track2)
    init_layer(self.fc_doa_track1)
    init_layer(self.fc_doa_track2)


def forward(self, x):
    """
    x: waveform, (batch_size, num_channels, data_length)
    """
    x_sed = x[:, :4]
    x_doa = x