Thanks for the reply. It is only the serialization that is the issue. I can perform a forward pass on the quantized model.
Below is a minimal reproducible example of the issue:
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx
from torchlibrosa.stft import LogmelFilterBank, Spectrogram
class AudioModel(nn.Module):
    """Minimal CNN audio tagger: torchlibrosa log-mel frontend, one conv
    block, and two fully-connected heads (embedding + clipwise scores)."""

    def __init__(self, classes_num):
        super(AudioModel, self).__init__()
        # torchlibrosa frontends — these are the non-traceable modules
        # excluded from FX quantization further down in the script.
        self.spectrogram_extractor = Spectrogram(
            n_fft=512, hop_length=160, win_length=512, window='hann',
            center=True, pad_mode='reflect', freeze_parameters=True)
        self.logmel_extractor = LogmelFilterBank(
            sr=16000, n_fft=512, n_mels=64, fmin=50, fmax=8000,
            ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True)
        self.bn0 = nn.BatchNorm2d(64)
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=(3, 3),
            stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64, 64, bias=True)
        self.fc_audioset = nn.Linear(64, classes_num, bias=True)

    def forward(self, input):
        """Return {'clipwise_output': sigmoid scores, 'embedding': features}.

        `input` is a raw waveform batch; the frontend expands it to
        (batch_size, 1, time_steps, freq_bins).
        """
        feat = self.spectrogram_extractor(input)
        feat = self.logmel_extractor(feat)  # (batch_size, 1, time_steps, mel_bins)
        # BatchNorm over the mel axis: swap it into the channel slot and back.
        feat = feat.transpose(1, 3)
        feat = self.bn0(feat)
        feat = feat.transpose(1, 3)
        feat = F.relu_(self.bn1(self.conv1(feat)))
        feat = F.avg_pool2d(feat, kernel_size=(2, 2))
        feat = F.dropout(feat, p=0.2, training=self.training)
        # Collapse frequency, then pool time with max + mean.
        feat = torch.mean(feat, dim=3)
        (time_max, _) = torch.max(feat, dim=2)
        time_mean = torch.mean(feat, dim=2)
        pooled = time_max + time_mean
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        pooled = F.relu_(self.fc1(pooled))
        # NOTE: the classifier head reads the pre-dropout activations,
        # matching the reference implementation this example is based on.
        embedding = F.dropout(pooled, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(pooled))
        return {'clipwise_output': clipwise_output, 'embedding': embedding}
# Build the float baseline and an eval-mode deep copy to quantize.
float_model = AudioModel(classes_num=21)
model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()

# FX graph-mode post-training static quantization, x86 (fbgemm) backend.
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
# The torchlibrosa frontends cannot be symbolically traced; keep them
# out of the quantized graph.
prep_config_dict = {
    "non_traceable_module_name": ["spectrogram_extractor", "logmel_extractor"]
}
prepared_model = prepare_fx(
    model_to_quantize,
    qconfig_dict,
    prepare_custom_config_dict=prep_config_dict,
)
def calibrate(model, data_loader):
    """Feed every batch of `data_loader` through `model` in eval mode.

    Runs under `torch.no_grad()` so the prepared model's observers can
    record activation ranges without building autograd state. Labels in
    each (inputs, labels) pair are ignored.
    """
    model.eval()
    with torch.no_grad():
        for batch, _ in data_loader:
            model(batch)
# One dummy (waveform, label) pair stands in for a real DataLoader.
dummy_input = [[torch.rand(1, 30000), 0]]
calibrate(prepared_model, dummy_input)  # run calibration on sample data
quantized_model = convert_fx(prepared_model)

def _param_count_millions(model):
    """Total number of parameters in `model`, expressed in millions."""
    return sum(np.prod(p.size()) for p in model.parameters()) / 1e6

# BUG FIX: the original printed the raw element count with an "M" suffix
# (e.g. "Number of Parameters: 41573M"); divide by 1e6 as the adjacent
# commented-out format string intended.
print(f"Number of Parameters: {_param_count_millions(float_model):.1f}M")
print(f"Number of Parameters: {_param_count_millions(quantized_model):.1f}M")

# The forward pass on the quantized model works...
quantized_model(dummy_input[0][0])
# ...serialization via TorchScript is where the reported failure occurs.
torch.jit.save(torch.jit.script(quantized_model), 'test.pth')
loaded_quantized = torch.jit.load('test.pth')
I also tried Eager-mode quantization, dequantizing around the arithmetic operations and the spectrogram modules, but the problem still occurs when trying to save.