I have a similar problem: I have a pretrained CNN14 model that I want to use as the base model for an audio detection task in a multitask framework. I need to expand the in_channels of the first conv block from 1 to 4 and 7 for the two subnetworks (SED and DOA, respectively).
Below is the pretrained CNN14 model:
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_utils import do_mixup, interpolate, pad_framewise_output
def init_layer(layer):
"""Initialize a Linear or Convolutional layer. """
nn.init.xavier_uniform_(layer.weight)
if hasattr(layer, 'bias'):
if layer.bias is not None:
layer.bias.data.fill_(0.)
def init_bn(bn):
"""Initialize a Batchnorm layer. """
bn.bias.data.fill_(0.)
bn.weight.data.fill_(1.)
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1), bias=False)
self.conv2 = nn.Conv2d(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1), bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.init_weight()
def init_weight(self):
init_layer(self.conv1)
init_layer(self.conv2)
init_bn(self.bn1)
init_bn(self.bn2)
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
x = input
x = F.relu_(self.bn1(self.conv1(x)))
x = F.relu_(self.bn2(self.conv2(x)))
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x1 = F.avg_pool2d(x, kernel_size=pool_size)
x2 = F.max_pool2d(x, kernel_size=pool_size)
x = x1 + x2
else:
raise Exception('Incorrect argument!')
return x
class Cnn14(nn.Module):
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
fmax, classes_num):
super(Cnn14, self).__init__()
window = 'hann'
center = True
pad_mode = 'reflect'
ref = 1.0
amin = 1e-10
top_db = None
# Spectrogram extractor
self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
win_length=window_size, window=window, center=center, pad_mode=pad_mode,
freeze_parameters=True)
# Logmel feature extractor
self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
freeze_parameters=True)
# Spec augmenter
self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
freq_drop_width=8, freq_stripes_num=2)
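        # NOTE: bn0 normalizes over the mel-bin axis (forward() transposes so that
        # axis sits in the channel dim), so the hard-coded 64 assumes mel_bins == 64.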
self.bn0 = nn.BatchNorm2d(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
self.fc1 = nn.Linear(2048, 2048, bias=True)
self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
self.init_weight()
def init_weight(self):
init_bn(self.bn0)
init_layer(self.fc1)
init_layer(self.fc_audioset)
    def change_first_layer(self, new_in_channels):
        """Rebuild conv_block1 so it takes `new_in_channels` input channels
        instead of the single channel Cnn14 was pretrained on.
        Note: the replacement block is randomly initialized."""
        out_channels = self.conv_block1.conv2.out_channels
        self.conv_block1 = ConvBlock(in_channels=new_in_channels,
                                     out_channels=out_channels)
def forward(self, input, mixup_lambda=None):
"""
Input: (batch_size, data_length)"""
x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
x = x.transpose(1, 3)
x = self.bn0(x)
x = x.transpose(1, 3)
if self.training:
x = self.spec_augmenter(x)
# Mixup on spectrogram
if self.training and mixup_lambda is not None:
x = do_mixup(x, mixup_lambda)
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
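        # Aggregate to a clip-level embedding: mean over mel bins, then
        # max + mean over time.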
x = torch.mean(x, dim=3)
(x1, _) = torch.max(x, dim=2)
x2 = torch.mean(x, dim=2)
x = x1 + x2
x = F.dropout(x, p=0.5, training=self.training)
x = F.relu_(self.fc1(x))
embedding = F.dropout(x, p=0.5, training=self.training)
clipwise_output = torch.sigmoid(self.fc_audioset(x))
output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}
return output_dict
class Transfer_Cnn14(nn.Module):
def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
fmax, classes_num, freeze_base):
"""Classifier for a new task using pretrained Cnn14 as a sub module.
"""
super(Transfer_Cnn14, self).__init__()
        audioset_classes_num = 13  # 527 in the original AudioSet checkpoint
self.base = Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin,
fmax, audioset_classes_num)
        # Transfer to another task layer
        self.fc_transfer = nn.Linear(2048, classes_num, bias=True)
if freeze_base:
# Freeze AudioSet pretrained layers
for param in self.base.parameters():
param.requires_grad = False
self.init_weights()
def init_weights(self):
init_layer(self.fc_transfer)
def load_from_pretrain(self, pretrained_checkpoint_path):
checkpoint = torch.load(pretrained_checkpoint_path)
self.base.load_state_dict(checkpoint['model'])
def forward(self, input, mixup_lambda=None):
"""Input: (batch_size, data_length)
"""
output_dict = self.base(input, mixup_lambda)
embedding = output_dict['embedding']
clipwise_output = torch.log_softmax(self.fc_transfer(embedding), dim=-1)
output_dict['clipwise_output'] = clipwise_output
        return output_dict
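For context, what I ultimately want change_first_layer to achieve is something like the sketch below: rebuild conv1 of conv_block1 with more input channels while reusing the pretrained single-channel kernel. The helper name expand_conv_block1 is mine, and copying the pretrained kernel into every new channel with rescaling is just one heuristic, not necessarily the right one:

import torch
import torch.nn as nn

def expand_conv_block1(model, new_in_channels):
    """Hypothetical helper: widen conv_block1.conv1 of a pretrained Cnn14
    from 1 to `new_in_channels` input channels, reusing the old weights."""
    old_conv = model.conv_block1.conv1
    new_conv = nn.Conv2d(in_channels=new_in_channels,
                         out_channels=old_conv.out_channels,
                         kernel_size=old_conv.kernel_size,
                         stride=old_conv.stride,
                         padding=old_conv.padding,
                         bias=False)
    with torch.no_grad():
        # Copy the pretrained single-channel kernel into every new input
        # channel, scaled so activation magnitudes stay roughly unchanged.
        new_conv.weight.copy_(
            old_conv.weight.repeat(1, new_in_channels, 1, 1) / new_in_channels)
    model.conv_block1.conv1 = new_conv

Calling this after load_from_pretrain, rather than before, would keep the checkpoint compatible with the original 1-channel architecture at load time.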
I tried to transfer the pretrained base model into the new task:
class EINV2(nn.Module):
    def __init__(self, cfg, dataset):
        super().__init__()
        self.pe_enable = False  # True | False
        self.in_channels_sed = 4
        self.in_channels_doa = 7
        if cfg['data']['audio_feature'] == 'logmel&intensity':
            self.f_bins = cfg['data']['n_mels']
            self.sed = nn.Sequential(
                Transfer_Cnn14(in_channels=4, classes_num=14, freeze_base=False),
                nn.AvgPool2d(kernel_size=(2, 2))
            )
            self.doa = nn.Sequential(
                Transfer_Cnn14(in_channels=7, classes_num=14, freeze_base=False),
                nn.AvgPool2d(kernel_size=(2, 2))
            )
if self.pe_enable:
self.pe = PositionalEncoding(pos_len=100, d_model=512, pe_type='t', dropout=0.0)
self.sed_trans_track1 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.sed_trans_track2 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.doa_trans_track1 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.doa_trans_track2 = nn.TransformerEncoder(
nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=1024, dropout=0.2), num_layers=2)
self.fc_sed_track1 = nn.Linear(512, 14, bias=True)
self.fc_sed_track2 = nn.Linear(512, 14, bias=True)
self.fc_doa_track1 = nn.Linear(512, 3, bias=True)
self.fc_doa_track2 = nn.Linear(512, 3, bias=True)
self.final_act_sed = nn.Sequential() # nn.Sigmoid()
self.final_act_doa = nn.Tanh()
self.init_weight()
def init_weight(self):
init_layer(self.fc_sed_track1)
init_layer(self.fc_sed_track2)
init_layer(self.fc_doa_track1)
init_layer(self.fc_doa_track2)
def forward(self, x):
"""
x: waveform, (batch_size, num_channels, data_length)
"""
        x_sed = x[:, :4]  # first 4 channels: log-mel features (SED branch)
        x_doa = x         # all 7 channels: log-mel + intensity (DOA branch)
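For the sed/doa calls above to work at all, I assume Transfer_Cnn14 would need to accept an in_channels argument and apply the first-layer change after loading the checkpoint, roughly like this. This is only a sketch of what I think is needed, reusing the hypothetical expand_conv_block1 from above; the audio-frontend arguments would presumably have to come from cfg:

class Transfer_Cnn14(nn.Module):
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num, freeze_base, in_channels=1):
        super().__init__()
        audioset_classes_num = 13  # must match the pretrained checkpoint's head
        self.base = Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin,
                          fmax, audioset_classes_num)
        self.in_channels = in_channels
        self.fc_transfer = nn.Linear(2048, classes_num, bias=True)
        if freeze_base:
            # Freeze AudioSet pretrained layers
            for param in self.base.parameters():
                param.requires_grad = False
        init_layer(self.fc_transfer)

    def load_from_pretrain(self, pretrained_checkpoint_path):
        checkpoint = torch.load(pretrained_checkpoint_path, map_location='cpu')
        self.base.load_state_dict(checkpoint['model'])
        if self.in_channels != 1:
            # Widen the first conv only after the 1-channel weights are loaded.
            expand_conv_block1(self.base, self.in_channels)

What I am still unsure about is the front end: Cnn14.forward expects a raw waveform and builds a single-channel spectrogram internally, so the multichannel log-mel + intensity features would presumably have to bypass or replace spectrogram_extractor/logmel_extractor. How should I wire this up so both subnetworks accept 4- and 7-channel inputs?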