Questions on QAT for Wav2Vec

Hello,

I am currently facing an issue while trying to apply quantization-aware training (QAT) to the pre-trained model retrieved through:

torchaudio.pipelines.WAV2VEC2_ASR_BASE_100H
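
Concretely, the model comes from the bundle's standard get_model() API:

import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_100H
model = bundle.get_model()  # torchaudio.models.Wav2Vec2Model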

At first I only want to apply QAT to the attention layers in the encoder, and then, once that works, apply it to the Conv layers in the feature extractor as well. I do not touch the positional embedding or the LayerNorm layers. The purpose is to measure the difference between configurations.

I am using this qconfig for each module corresponding to a linear layer:

module.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.FakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver,
        dtype=torch.qint8),
    weight=torch.quantization.FakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver,
        dtype=torch.qint8,
        qscheme=torch.per_tensor_affine))

The qconfig of all other modules is set to None.

And then on the whole model:

torch.quantization.prepare_qat(model, inplace=True)
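
My selection logic looks roughly like this (a sketch rather than my exact code; the ".attention." name filter assumes torchaudio's module naming, where the encoder's self-attention projections are Linear layers under encoder.transformer.layers[*].attention):

import torch
import torch.nn as nn

# qat_qconfig is the QConfig shown above
for name, module in model.named_modules():
    # Only the encoder attention projections get a qconfig; everything
    # else is explicitly disabled so prepare_qat leaves it untouched.
    if isinstance(module, nn.Linear) and ".attention." in name:
        module.qconfig = qat_qconfig
    else:
        module.qconfig = None

model.train()  # prepare_qat expects a model in training mode
torch.quantization.prepare_qat(model, inplace=True)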

Training runs without issues, and I have done QAT both sequentially and on all the linear layers at the same time.

I am then saving the model using:

quantized_model = torch.quantization.convert(model.to("cpu").eval(), inplace=False)
torch.save(quantized_model.state_dict(), model_save_path)

I can load the quantized model, and when printing it, it looks as expected.
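
(For completeness, loading goes through the same prepare/convert steps on a fresh model before load_state_dict, since the converted architecture has to match the saved state dict. set_qconfigs here is a placeholder for the qconfig assignment shown above.)

import torch
import torchaudio

model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_100H.get_model()
set_qconfigs(model)  # placeholder: same qconfig assignment as above
model.train()
torch.quantization.prepare_qat(model, inplace=True)
quantized_model = torch.quantization.convert(model.to("cpu").eval(), inplace=False)
quantized_model.load_state_dict(torch.load(model_save_path))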

My issue arises when running validation on the quantized model: the forward pass fails. For example, I tried this:

import torch
import torch.nn as nn
import torch.quantization as quant

class QuantizedWav2VecModel(nn.Module):

    def __init__(self, pretrained_model):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.apply_quant = quant.QuantStub()
        self.apply_dequant = quant.DeQuantStub()

    def forward(self, input_values, input_lengths):
        # Quantize the raw waveform, run the (partially quantized) model,
        # then dequantize the emissions back to float.
        quantized_input = self.apply_quant(input_values)
        output, output_size = self.pretrained_model(quantized_input, input_lengths)
        dequantized_output = self.apply_dequant(output)
        return dequantized_output, output_size

As some of my layers are not quantized, applying QuantStub to the input feeds a quantized tensor into the float Conv layers of the feature extractor, and I run into this type of error (on CPU, Conv1d falls back to the 2d kernel, hence the op name):

NotImplementedError: Could not run 'aten::_slow_conv2d_forward' with arguments from the 'QuantizedCPU' backend.

I tried changing my forward pass to check, for each module and its children, whether I should apply QuantStub and DeQuantStub, but then I also have to keep track of what goes into the forward pass of Wav2Vec, as some layers do not take the input lengths. This raises other errors as well.
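
(What I was going for is roughly the pattern below; a sketch rather than my exact code. QuantWrapper is the stock torch.quantization helper that puts a QuantStub/DeQuantStub pair around a single module, so only the wrapped Linear ever sees quantized tensors and the float layers around it are left alone.)

import torch
import torch.nn as nn

def wrap_selected_linears(parent, qconfig):
    # Recursively replace each Linear that should be quantized with
    # QuantWrapper(linear), so the float -> int8 -> float boundary
    # sits around that layer only.
    for name, child in parent.named_children():
        if isinstance(child, nn.Linear) and getattr(child, "qconfig", None) is not None:
            wrapper = torch.quantization.QuantWrapper(child)
            wrapper.qconfig = qconfig  # the stubs need a qconfig too
            setattr(parent, name, wrapper)
        else:
            wrap_selected_linears(child, qconfig)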

My questions are:

1- Is there another way to use QAT on this pretrained model, rather than the one I am using? I would like to be able to apply QAT to whatever layer or group of layers I want.

2- Would the cleanest way be to reimplement Wav2Vec and then load the weights from the pre-trained model?

Note: I didn't use torch.ao.quantization, as I had the following error with it (presumably because the positional convolution in the encoder is wrapped in weight_norm, which registers a parametrization):

AssertionError: nnq.Conv1d.from_float only works for Conv1d but got:<class 'torch.nn.utils.parametrize.ParametrizedConv1d'> in the method _ConvNd.from_float()

Thank you.