Serializing a quantized model that uses an external library with TorchScript

Hi. I have an audio processing model that uses functions from TorchLibrosa (qiuqiangkong/torchlibrosa on GitHub). I want to quantize the model and save it, but I get the following error:

RuntimeError:

_pad(Tensor input, int[] pad, str mode="constant", float value=0.) -> (Tensor):
Expected a value of type 'float' for argument 'value' but instead found type 'int'.
:
input_1 = input
getitem = input_1[(slice(None, None, None), None, slice(None, None, None))]; input_1 = None
_pad_1 = torch.nn.functional._pad(getitem, (256, 256), mode = 'reflect', value = 0); getitem = None
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
base_spectrogram_extractor_stft_conv_real_input_scale_0 = self.base_spectrogram_extractor_stft_conv_real_input_scale_0
base_spectrogram_extractor_stft_conv_real_input_zero_point_0 = self.base_spectrogram_extractor_stft_conv_real_input_zero_point_0

I tried excluding these modules from tracing with the following config for FX Graph Mode quantization:

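from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

qconfig = get_default_qconfig("fbgemm")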
qconfig_dict = {"": qconfig}

prep_config_dict = {
    "non_traceable_module_name": ["spectrogram_extractor", "logmel_extractor"]
}
prepared_model = prepare_fx(
    model_to_quantize, qconfig_dict, prepare_custom_config_dict=prep_config_dict)

What can I do to fix this issue? I really appreciate any help you can provide.

Can you provide some additional context? Is the serialization the only issue? Can you run the quantized model without saving it, and so on?

Ideally, a way to reproduce the error would be most helpful.

Thanks for the reply. It is only the serialization that is the issue. I can perform a forward pass on the quantized model.

Here is a minimal reproducible example of the issue:

import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx
from torchlibrosa.stft import LogmelFilterBank, Spectrogram


class AudioModel(nn.Module):
    def __init__(self, classes_num):

        super(AudioModel, self).__init__()

        self.spectrogram_extractor = Spectrogram(n_fft=512, hop_length=160,
                                                 win_length=512, window='hann', center=True, pad_mode='reflect',
                                                 freeze_parameters=True)
        self.logmel_extractor = LogmelFilterBank(sr=16000, n_fft=512,
                                                 n_mels=64, fmin=50, fmax=8000, ref=1.0, amin=1e-10, top_db=None,
                                                 freeze_parameters=True)
        self.bn0 = nn.BatchNorm2d(64)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(
            3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64, 64, bias=True)
        self.fc_audioset = nn.Linear(64, classes_num, bias=True)

    def forward(self, input):
        # input: (batch_size, data_length)
        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.avg_pool2d(x, kernel_size=(2, 2))
        x = F.dropout(x, p=0.2, training=self.training)

        x = torch.mean(x, dim=3)

        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {'clipwise_output': clipwise_output,
                       'embedding': embedding}

        return output_dict


float_model = AudioModel(classes_num=21)

model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()

qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

prep_config_dict = {
    "non_traceable_module_name": ["spectrogram_extractor", "logmel_extractor"]
}
prepared_model = prepare_fx(
    model_to_quantize, qconfig_dict, prepare_custom_config_dict=prep_config_dict)


def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for X, _ in data_loader:
            model(X)


dummy_input = [[torch.rand(1, 30000), 0]]
calibrate(prepared_model, dummy_input)  # run calibration on sample data

quantized_model = convert_fx(prepared_model)

params = sum([np.prod(p.size()) for p in float_model.parameters()])
print(f"Number of parameters (float): {params / 1e6:.1f}M")
# Quantized modules keep their weights in packed form, so parameters() may not count them
params = sum([np.prod(p.size()) for p in quantized_model.parameters()])
print(f"Number of parameters (quantized): {params / 1e6:.1f}M")

quantized_model(dummy_input[0][0])

# Scripting the converted GraphModule is where the RuntimeError above is raised
torch.jit.save(torch.jit.script(quantized_model),
               'test.pth')
loaded_quantized = torch.jit.load('test.pth')

I also tried Eager Mode quantization, dequantizing before the arithmetic operations and the spectrogram layers, but the problem still occurs when saving.

Hey, I haven't had a chance to look at this in depth yet. It would be easier if I knew which line was causing the issue, but I suspect something is producing an integer where an observer expects a float, or something along those lines. If you cast the output of the TorchLibrosa functions to float, does that change anything?

Casting to float does not solve the problem.

Have you had time to look into this issue?

I managed to fix the problem by removing the following line:
x = F.pad(x, pad=(self.n_fft // 2, self.n_fft // 2), mode=self.pad_mode)
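Removing that line also removes the centered padding, so the spectrogram ends up with fewer frames than before. If the padding matters, one alternative that may be worth trying (not tested against the original setup) is to do the reflect padding with an nn.ReflectionPad1d module instead of F.pad. The reasoning, based on the error message: the traced graph records the call as torch.nn.functional._pad(..., value = 0) with an int, even though the original line never passes value, and TorchScript rejects that int for the float argument. Standard torch.nn modules, by contrast, are kept as leaf module calls by FX, so no F.pad defaults get written into the generated code. PaddedFrontEnd below is a hypothetical stand-in for the STFT front end, not TorchLibrosa's code, and the sketch only checks FX tracing plus TorchScript save/load, without quantization:

```python
import io

import torch
import torch.fx
import torch.nn as nn


class PaddedFrontEnd(nn.Module):
    """Hypothetical stand-in for an STFT front end: reflect-pad, then a strided 1-D conv."""

    def __init__(self, n_fft: int = 512, hop_length: int = 160):
        super().__init__()
        # A module instead of F.pad(..., mode='reflect'): FX keeps standard
        # torch.nn modules as leaf calls, so no F.pad default arguments end up
        # in the generated code that TorchScript has to type-check.
        self.pad = nn.ReflectionPad1d((n_fft // 2, n_fft // 2))
        self.conv_real = nn.Conv1d(1, n_fft // 2 + 1, kernel_size=n_fft,
                                   stride=hop_length, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)         # (batch, 1, samples)
        x = self.pad(x)            # (batch, 1, samples + n_fft)
        return self.conv_real(x)   # (batch, n_fft // 2 + 1, frames)


model = PaddedFrontEnd().eval()

# Same pipeline shape as the failing code: FX trace, then script the GraphModule.
traced = torch.fx.symbolic_trace(model)
scripted = torch.jit.script(traced)

# Round-trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.rand(1, 30000)
with torch.no_grad():
    out_eager = model(x)
    out_loaded = loaded(x)

print(out_loaded.shape, torch.allclose(out_eager, out_loaded))
```

For the 30000-sample dummy input this gives 257 frequency bins over 188 frames; the same conv without the padding would give 185 frames, which is the difference removing the F.pad line introduces.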