I've been working with audio AI for a while and was experimenting with the PANNs model. My goal was to modify it so that it could be trained on additional classes and predict them in addition to the classes it already predicted. The version of PANNs that I am using has an attention layer at the end. I simply added a second attention layer whose number of output units equals the number of extra classes, and trained only this new layer. For predictions, I concatenated the outputs of the two attention layers.
Now here’s the code for my model that finally worked:
class CustomPanns(nn.Module):
    """Frozen PANNs backbone plus one trainable attention head for extra classes.

    The backbone's weights are loaded from ``wts_path`` and excluded from
    gradient updates; only the layers created directly on this module
    (notably ``attn``) remain trainable. ``forward`` concatenates the
    backbone's clip-wise scores with the new head's clip-wise scores.

    Args:
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax, apply_aug,
        top_db: Forwarded unchanged to ``PANNsDense121Att``.
        act_classes: Number of output units of the new attention head.
        classes_num: Number of classes the backbone was built with.
        wts_path: Path of a checkpoint whose ``"model"`` entry holds the
            backbone's state dict.
    """

    def __init__(self, sample_rate: int, window_size: int, hop_size: int,
                 mel_bins: int, fmin: int, fmax: int, act_classes: int,
                 classes_num: int, apply_aug: bool, wts_path: str,
                 top_db=None):
        super().__init__()
        # NOTE(review): bn0 and fc1 are never used in forward(). They are kept
        # so the state_dict keys of already-saved CustomPanns checkpoints
        # still match.
        self.bn0 = nn.BatchNorm2d(mel_bins)
        self.fc1 = nn.Linear(1024, 1024, bias=True)
        self.attn = AttBlock(1024, act_classes, activation='sigmoid')
        self.backbone = PANNsDense121Att(sample_rate, window_size, hop_size,
                                         mel_bins, fmin, fmax, classes_num,
                                         apply_aug, top_db)

        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host; load_state_dict copies the values into the backbone's tensors.
        checkpoint = torch.load(wts_path, map_location="cpu")
        self.backbone.load_state_dict(checkpoint["model"])
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Apply the BatchNorm freeze implemented in train() right away, since
        # a freshly built module already reports training=True.
        self.train()

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping backbone BatchNorm in eval mode.

        ``requires_grad=False`` only stops gradient updates. BatchNorm layers
        in train mode still overwrite their running statistics on every
        forward pass, which would silently change the frozen backbone's
        predictions. Other backbone layers follow ``mode`` as usual.
        """
        super().train(mode)
        for module in self.backbone.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()
        return self

    def forward(self, input_data):
        """Run the backbone and the extra head, then merge their clip scores.

        Args:
            input_data: Passed to the backbone as-is; the backbone unpacks it
                as ``(input_x, mixup_lambda)``.

        Returns:
            The backbone's output dict with an added ``"combined_output"``
            entry: the backbone's ``"clipwise_output"`` concatenated on the
            last axis with the new head's clip-wise output.
        """
        features, output_dict = self.backbone(input_data)
        clipwise_output, _norm_att, _segmentwise_output = self.attn(features)

        backbone_clip = output_dict["clipwise_output"]
        # The new head's output has one row per flattened item; fold it back
        # into the backbone's (batch, chunks, ...) layout so the concatenation
        # also works with more than one chunk per sample.
        batch, chunks = backbone_clip.shape[:2]
        extra_clip = clipwise_output.reshape(batch, chunks, -1)

        output_dict["combined_output"] = torch.cat(
            [backbone_clip, extra_clip], dim=2)
        return output_dict
Funny thing is, initially I had written a loop in which I accessed my backbone layers a little differently in the forward method: instead of directly calling `self.backbone.forward`, I was in effect rewriting the backbone's forward method, i.e. `features = self.backbone.cnn_feature_extractor(...)`, and then operating on those features step by step. What I realized is that, although the weights were the same and the model was frozen, in that case I got a completely different output compared to the method above. I just wanted to know the reason for that:
Code for the backbone model:
class PANNsDense121Att(nn.Module):
    """Log-mel front end, DenseNet-121 feature stack and an attention block.

    ``forward`` returns both the features fed to the attention block and a
    dict with frame-wise and clip-wise outputs.
    """

    def __init__(self, sample_rate: int, window_size: int, hop_size: int,
                 mel_bins: int, fmin: int, fmax: int, classes_num: int,
                 apply_aug: bool, top_db=None):
        super().__init__()

        self.interpolate_ratio = 32  # Downsampled ratio
        self.apply_aug = apply_aug

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window='hann',
            center=True,
            pad_mode='reflect',
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=1.0,
            amin=1e-10,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        self.bn0 = nn.BatchNorm2d(mel_bins)
        self.fc1 = nn.Linear(1024, 1024, bias=True)
        self.att_block = AttBlock(1024, classes_num, activation='sigmoid')
        # NOTE(review): recent torchvision releases prefer `weights=None` over
        # `pretrained=False`; left as-is to stay compatible with the installed
        # version.
        self.densenet_features = models.densenet121(pretrained=False).features

        self.init_weight()

    def init_weight(self):
        """Initialise ``bn0`` and ``fc1`` with the project's init helpers."""
        init_bn(self.bn0)
        init_layer(self.fc1)

    def cnn_feature_extractor(self, x):
        """Apply the DenseNet feature stack to ``x`` and return the result."""
        return self.densenet_features(x)

    def preprocess(self, input_x, mixup_lambda=None):
        """Turn waveforms into normalised (optionally augmented) log-mel input.

        Returns:
            ``(logmel, frames_num)`` where ``frames_num`` is the size of
            axis 2 of the log-mel tensor.
        """
        # NOTE(review): `mixup_lambda` is accepted but never used here.
        spec = self.spectrogram_extractor(input_x)  # (batch_size, 1, time_steps, freq_bins)
        logmel = self.logmel_extractor(spec)  # (batch_size, 1, time_steps, mel_bins)
        frames_num = logmel.shape[2]

        # bn0 normalises over axis 1, so swap axis 3 into that position and
        # swap it back afterwards.
        logmel = self.bn0(logmel.transpose(1, 3)).transpose(1, 3)

        if self.apply_aug:
            logmel = self.spec_augmenter(logmel)
        return logmel, frames_num

    def forward(self, input_data):
        """Compute attention features and frame-wise / clip-wise outputs.

        Args:
            input_data: A pair ``(input_x, mixup_lambda)``. ``input_x`` must be
                three-dimensional; its first two axes are flattened together
                before preprocessing.

        Returns:
            ``(features, output_dict)``: the tensor passed to the attention
            block, and a dict with ``'framewise_output'`` and
            ``'clipwise_output'`` reshaped back to the leading two axes.
        """
        input_x, mixup_lambda = input_data
        batch, chunks, samples = input_x.shape

        flat_audio = input_x.reshape(batch * chunks, samples)
        spec, frames_num = self.preprocess(flat_audio,
                                           mixup_lambda=mixup_lambda)
        if mixup_lambda is not None:
            # NOTE(review): the batch size is halved here, but preprocess()
            # does not mix samples; confirm the mixing happens somewhere.
            batch = (batch * chunks) // 2
            chunks = 1

        # Output shape (batch size, channels, time, frequency)
        spec = spec.expand(spec.shape[0], 3, spec.shape[2], spec.shape[3])
        feats = self.cnn_feature_extractor(spec)

        # Aggregate in frequency axis
        feats = torch.mean(feats, dim=3)

        max_pooled = F.max_pool1d(feats, kernel_size=3, stride=1, padding=1)
        avg_pooled = F.avg_pool1d(feats, kernel_size=3, stride=1, padding=1)
        feats = max_pooled + avg_pooled

        feats = F.dropout(feats, p=0.5, training=self.training)
        # fc1 acts on the last axis, so swap axes 1 and 2 around the call.
        feats = feats.transpose(1, 2)
        feats = F.relu_(self.fc1(feats))
        feats = feats.transpose(1, 2)
        feats = F.dropout(feats, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(feats)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        # Get framewise output
        framewise_output = pad_framewise_output(
            interpolate(segmentwise_output, self.interpolate_ratio),
            frames_num,
        )

        frame_shape = framewise_output.shape
        clip_shape = clipwise_output.shape
        output_dict = {
            'framewise_output': framewise_output.reshape(
                batch, chunks, frame_shape[1], frame_shape[2]),
            'clipwise_output': clipwise_output.reshape(
                batch, chunks, clip_shape[1]),
        }
        return feats, output_dict