RuntimeError: promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: QUInt8 Float

I performed quantization on a PyTorch model (in case it helps, the model is HERE)

class QuantizedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        
        self.quant = torch.quantization.QuantStub()
        self.model = model
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

quantized_model = QuantizedModel(args.model)

# model must be set to eval mode for static quantization logic to work
quantized_model.eval()

backend = "qnnpack" # good for mobile cpus #fbgemm qnnpack x86
quantized_model.qconfig = torch.ao.quantization.get_default_qconfig(backend)

# the following prevents "AssertionError: Embedding quantization is only supported with float_qparams_weight_only_qconfig."
for _, mod in quantized_model.named_modules():
    if isinstance(mod, torch.nn.Embedding):
        mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

# insert observers
model_static_quantized = torch.quantization.prepare(quantized_model, inplace=False)
model_static_prepared = torch.quantization.convert(model_static_quantized, inplace=False)

# Convert the observed model to a quantized model.
model_static_convert = torch.ao.quantization.convert(model_static_prepared)
torch.save(model_static_convert.state_dict(), 'model_static_convert.pth')

However, if I run

summary(model_static_convert, input_size=torch.Size([1, 256, 192]))

I get the following error

RuntimeError: promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: QUInt8 Float

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Quantize: 1, Conv2d: 3, Linear: 3]

The same error occurs for model_static_prepared, while I can still run summary on quantized_model. Any thoughts?

Can you print the quantized model after convert? Also, why do you call torch.ao.quantization.convert multiple times?

Hi, you’re right. I’ll write the revised version below:

class QuantizedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        
        self.quant = torch.quantization.QuantStub()
        self.model = model
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

quantized_model = QuantizedModel(args.model)

# model must be set to eval mode for static quantization logic to work
quantized_model.eval()

backend = "qnnpack" # good for mobile cpus #fbgemm qnnpack x86
quantized_model.qconfig = torch.ao.quantization.get_default_qconfig(backend)

# the following prevents "AssertionError: Embedding quantization is only supported with float_qparams_weight_only_qconfig."
for _, mod in quantized_model.named_modules():
    if isinstance(mod, torch.nn.Embedding):
        mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

# insert observers
model_static_prepared = torch.quantization.prepare(quantized_model, inplace=False)

# Convert the observed model to a quantized model.
model_static_convert = torch.ao.quantization.convert(model_static_prepared, inplace=False)
torch.save(model_static_convert.state_dict(), 'model_static_convert.pth')
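Not shown above: between prepare and convert I also run a calibration pass over representative data so the observers can record activation ranges. Roughly like this, where calibration_loader is just a placeholder for my own data loader:

with torch.inference_mode():
    for spec in calibration_loader:
        # forward pass only; the HistogramObservers collect min/max statistics
        model_static_prepared(spec)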

If I print model_static_convert, I get

QuantizedModel(
  (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (model): Model_SPEC2MIDI(
    (encoder_spec2midi): Encoder_SPEC2MIDI(
      (conv): QuantizedConv2d(1, 4, kernel_size=(1, 5), stride=(1, 1), scale=1.0, zero_point=0)
      (tok_embedding_freq): QuantizedLinear(in_features=244, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (pos_embedding_freq): QuantizedEmbedding(num_embeddings=256, embedding_dim=256, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
      (layers_freq): ModuleList(
        (0-2): 3 x EncoderLayer(
          (layer_norm): QuantizedLayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_k): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_v): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_o): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (dropout): QuantizedDropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): QuantizedLinear(in_features=256, out_features=512, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_2): QuantizedLinear(in_features=512, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (dropout): QuantizedDropout(p=0.1, inplace=False)
          )
          (dropout): QuantizedDropout(p=0.1, inplace=False)
        )
      )
      (dropout): QuantizedDropout(p=0.1, inplace=False)
    )
    (decoder_spec2midi): Decoder_SPEC2MIDI(
      (sigmoid): Sigmoid()
      (dropout): QuantizedDropout(p=0.1, inplace=False)
      (pos_embedding_freq): QuantizedEmbedding(num_embeddings=88, embedding_dim=256, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
      (layer_zero_freq): DecoderLayer_Zero(
        (layer_norm): QuantizedLayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (encoder_attention): MultiHeadAttentionLayer(
          (fc_q): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (fc_k): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (fc_v): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (fc_o): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (dropout): QuantizedDropout(p=0.1, inplace=False)
        )
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (fc_1): QuantizedLinear(in_features=256, out_features=512, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (fc_2): QuantizedLinear(in_features=512, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (dropout): QuantizedDropout(p=0.1, inplace=False)
        )
        (dropout): QuantizedDropout(p=0.1, inplace=False)
      )
      (layers_freq): ModuleList(
        (0-1): 2 x DecoderLayer(
          (layer_norm): QuantizedLayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_k): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_v): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_o): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (dropout): QuantizedDropout(p=0.1, inplace=False)
          )
          (encoder_attention): MultiHeadAttentionLayer(
            (fc_q): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_k): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_v): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_o): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (dropout): QuantizedDropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): QuantizedLinear(in_features=256, out_features=512, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_2): QuantizedLinear(in_features=512, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (dropout): QuantizedDropout(p=0.1, inplace=False)
          )
          (dropout): QuantizedDropout(p=0.1, inplace=False)
        )
      )
      (fc_onset_freq): QuantizedLinear(in_features=256, out_features=1, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (fc_offset_freq): QuantizedLinear(in_features=256, out_features=1, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (fc_mpe_freq): QuantizedLinear(in_features=256, out_features=1, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (fc_velocity_freq): QuantizedLinear(in_features=256, out_features=128, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (pos_embedding_time): QuantizedEmbedding(num_embeddings=128, embedding_dim=256, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
      (layers_time): ModuleList(
        (0-2): 3 x EncoderLayer(
          (layer_norm): QuantizedLayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_k): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_v): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_o): QuantizedLinear(in_features=256, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (dropout): QuantizedDropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): QuantizedLinear(in_features=256, out_features=512, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (fc_2): QuantizedLinear(in_features=512, out_features=256, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
            (dropout): QuantizedDropout(p=0.1, inplace=False)
          )
          (dropout): QuantizedDropout(p=0.1, inplace=False)
        )
      )
      (fc_onset_time): QuantizedLinear(in_features=256, out_features=1, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (fc_offset_time): QuantizedLinear(in_features=256, out_features=1, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (fc_mpe_time): QuantizedLinear(in_features=256, out_features=1, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
      (fc_velocity_time): QuantizedLinear(in_features=256, out_features=128, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
    )
  )
  (dequant): DeQuantize()
)

For completeness, I’ll add model_static_prepared

QuantizedModel(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=2.0444393157958984e-05, max_val=0.9999867677688599)
  )
  (model): Model_SPEC2MIDI(
    (encoder_spec2midi): Encoder_SPEC2MIDI(
      (conv): Conv2d(
        1, 4, kernel_size=(1, 5), stride=(1, 1)
        (activation_post_process): HistogramObserver(min_val=-0.09601336717605591, max_val=0.3019489049911499)
      )
      (tok_embedding_freq): Linear(
        in_features=244, out_features=256, bias=True
        (activation_post_process): HistogramObserver(min_val=-0.37138867378234863, max_val=0.41479259729385376)
      )
      (pos_embedding_freq): Embedding(
        256, 256
        (activation_post_process): PlaceholderObserver(dtype=torch.float32, is_dynamic=False)
      )
      (layers_freq): ModuleList(
        (0): EncoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-3.684764862060547, max_val=5.701371192932129)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-65.43849182128906, max_val=92.99006652832031)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-51.54948043823242, max_val=48.792999267578125)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.7751903533935547, max_val=4.762444496154785)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-2.080702066421509, max_val=2.2342758178710938)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-26.2632999420166, max_val=8.377985954284668)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-31.172016143798828, max_val=29.386695861816406)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): EncoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-4.197020530700684, max_val=4.102721691131592)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-8.087146759033203, max_val=7.1710615158081055)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-7.2963480949401855, max_val=6.518189430236816)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.46532678604126, max_val=4.64891242980957)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.256741762161255, max_val=2.979290008544922)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-9.900541305541992, max_val=4.6360368728637695)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-10.176200866699219, max_val=9.425138473510742)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (2): EncoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-2.353666305541992, max_val=1.580255150794983)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.574862957000732, max_val=4.150031566619873)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.890504837036133, max_val=5.468308925628662)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.074532985687256, max_val=4.4661641120910645)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.340092658996582, max_val=2.5051510334014893)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-5.517178535461426, max_val=1.998004674911499)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.395416736602783, max_val=4.252961158752441)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (decoder_spec2midi): Decoder_SPEC2MIDI(
      (sigmoid): Sigmoid()
      (dropout): Dropout(p=0.1, inplace=False)
      (pos_embedding_freq): Embedding(
        88, 256
        (activation_post_process): PlaceholderObserver(dtype=torch.float32, is_dynamic=False)
      )
      (layer_zero_freq): DecoderLayer_Zero(
        (layer_norm): LayerNorm(
          (256,), eps=1e-05, elementwise_affine=True
          (activation_post_process): HistogramObserver(min_val=-2.159541130065918, max_val=1.896041989326477)
        )
        (encoder_attention): MultiHeadAttentionLayer(
          (fc_q): Linear(
            in_features=256, out_features=256, bias=True
            (activation_post_process): HistogramObserver(min_val=-5.465935707092285, max_val=5.7745466232299805)
          )
          (fc_k): Linear(
            in_features=256, out_features=256, bias=True
            (activation_post_process): HistogramObserver(min_val=-4.760313034057617, max_val=4.51210880279541)
          )
          (fc_v): Linear(
            in_features=256, out_features=256, bias=True
            (activation_post_process): HistogramObserver(min_val=-3.5816729068756104, max_val=3.889537811279297)
          )
          (fc_o): Linear(
            in_features=256, out_features=256, bias=True
            (activation_post_process): HistogramObserver(min_val=-11.872238159179688, max_val=10.520785331726074)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (fc_1): Linear(
            in_features=256, out_features=512, bias=True
            (activation_post_process): HistogramObserver(min_val=-10.62162971496582, max_val=4.755774021148682)
          )
          (fc_2): Linear(
            in_features=512, out_features=256, bias=True
            (activation_post_process): HistogramObserver(min_val=-10.505728721618652, max_val=5.482919692993164)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layers_freq): ModuleList(
        (0): DecoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-4.589219570159912, max_val=2.5894887447357178)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.243453025817871, max_val=3.9806787967681885)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.46553897857666, max_val=4.249303817749023)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.0113227367401123, max_val=3.2259938716888428)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-1.4668912887573242, max_val=1.394325852394104)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-6.250119686126709, max_val=6.184148788452148)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.0483317375183105, max_val=2.854008436203003)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-2.5895557403564453, max_val=2.5737767219543457)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.731720447540283, max_val=3.3778984546661377)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-10.828614234924316, max_val=4.990065097808838)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-21.926610946655273, max_val=9.560783386230469)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): DecoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-6.104793548583984, max_val=4.21198034286499)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.9555583000183105, max_val=3.8071987628936768)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.650763034820557, max_val=4.730352401733398)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.7091028690338135, max_val=5.490683555603027)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.407787322998047, max_val=4.873568058013916)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-7.223546504974365, max_val=5.260581970214844)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.510692119598389, max_val=4.071597099304199)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-2.9346799850463867, max_val=2.7791829109191895)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.698692321777344, max_val=14.345036506652832)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-11.626481056213379, max_val=3.7872540950775146)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-5.328554153442383, max_val=4.779942512512207)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (fc_onset_freq): Linear(
        in_features=256, out_features=1, bias=True
        (activation_post_process): HistogramObserver(min_val=-11.836112976074219, max_val=-8.34322738647461)
      )
      (fc_offset_freq): Linear(
        in_features=256, out_features=1, bias=True
        (activation_post_process): HistogramObserver(min_val=-14.73990249633789, max_val=-10.910430908203125)
      )
      (fc_mpe_freq): Linear(
        in_features=256, out_features=1, bias=True
        (activation_post_process): HistogramObserver(min_val=-6.061641693115234, max_val=-0.0050067780539393425)
      )
      (fc_velocity_freq): Linear(
        in_features=256, out_features=128, bias=True
        (activation_post_process): HistogramObserver(min_val=-27.80816078186035, max_val=8.943193435668945)
      )
      (pos_embedding_time): Embedding(
        128, 256
        (activation_post_process): PlaceholderObserver(dtype=torch.float32, is_dynamic=False)
      )
      (layers_time): ModuleList(
        (0): EncoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-5.802577495574951, max_val=3.041590929031372)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-509.1664123535156, max_val=503.69482421875)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-239.112060546875, max_val=220.1811981201172)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-10.914785385131836, max_val=12.554680824279785)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-17.876018524169922, max_val=9.295068740844727)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-6.394947052001953, max_val=3.2832441329956055)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-12.565784454345703, max_val=12.579952239990234)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): EncoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-5.003878116607666, max_val=3.5098938941955566)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.906929016113281, max_val=4.463215351104736)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-5.96990966796875, max_val=7.774462699890137)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-5.184789180755615, max_val=6.404623985290527)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-10.703141212463379, max_val=9.522540092468262)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-8.478245735168457, max_val=3.340111494064331)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.296571969985962, max_val=4.52815055847168)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (2): EncoderLayer(
          (layer_norm): LayerNorm(
            (256,), eps=1e-05, elementwise_affine=True
            (activation_post_process): HistogramObserver(min_val=-14.10239315032959, max_val=6.980414867401123)
          )
          (self_attention): MultiHeadAttentionLayer(
            (fc_q): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-4.058288097381592, max_val=3.779902219772339)
            )
            (fc_k): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-12.641512870788574, max_val=7.776096343994141)
            )
            (fc_v): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-10.134719848632812, max_val=7.545546531677246)
            )
            (fc_o): Linear(
              in_features=256, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-57.08567428588867, max_val=42.12849044799805)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (positionwise_feedforward): PositionwiseFeedforwardLayer(
            (fc_1): Linear(
              in_features=256, out_features=512, bias=True
              (activation_post_process): HistogramObserver(min_val=-10.095251083374023, max_val=2.2720048427581787)
            )
            (fc_2): Linear(
              in_features=512, out_features=256, bias=True
              (activation_post_process): HistogramObserver(min_val=-3.1800479888916016, max_val=3.1038708686828613)
            )
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (fc_onset_time): Linear(
        in_features=256, out_features=1, bias=True
        (activation_post_process): HistogramObserver(min_val=-15.957779884338379, max_val=-9.783061027526855)
      )
      (fc_offset_time): Linear(
        in_features=256, out_features=1, bias=True
        (activation_post_process): HistogramObserver(min_val=-17.46781349182129, max_val=-11.199623107910156)
      )
      (fc_mpe_time): Linear(
        in_features=256, out_features=1, bias=True
        (activation_post_process): HistogramObserver(min_val=-12.506057739257812, max_val=-0.1214350163936615)
      )
      (fc_velocity_time): Linear(
        in_features=256, out_features=128, bias=True
        (activation_post_process): HistogramObserver(min_val=-32.08559036254883, max_val=14.655937194824219)
      )
    )
  )
  (dequant): DeQuantStub()
)

OK, thanks. I think there are two things that might be going wrong: (1) do you need to put a DeQuantStub/QuantStub pair around the sigmoid, since it's not quantized? (2) It looks like you have a custom multi-head attention layer defined; I think you'll need to check whether anything can go wrong in there as well.
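As a rough sketch of what I mean by (1), with placeholder module names rather than your exact model:

class DecoderHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.sigmoid = torch.nn.Sigmoid()
        self.quant = torch.ao.quantization.QuantStub()

    def forward(self, x):
        x = self.dequant(x)   # back to float for the op that stays unquantized
        x = self.sigmoid(x)
        x = self.quant(x)     # re-quantize before the next quantized module
        return x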

Eager mode quantization feels more like a set of building blocks and utility functions than a complete flow, since it requires people to reason about exactly how to quantize the model: Quantization — PyTorch main documentation

If your model can be captured by export, I'd suggest starting with Quantization — PyTorch main documentation

Thanks for the response. I am reading through all of this, but my other attempts so far are failing (for example, I also tried FX graph mode quantization).
(1) I don't understand why the sigmoid was not quantized.
(2) I am still not sure how to exclude the MultiHeadAttentionLayer from being quantized.

The original error I got

RuntimeError: promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: QUInt8 Float

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Quantize: 1, Conv2d: 3, Linear: 3]

is related to the line

spec_freq = self.dropout((spec_emb_freq * self.scale_freq) + self.pos_embedding_freq(pos_freq))

So it seems it is the dropout layer that requires some sort of adjustment. Could you please give me another hint? Thanks.

I also tried exporting a toy model, following the prototype on the page you linked.

from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import (
  prepare_qat_pt2e,
  convert_pt2e,
)

example_inputs = (torch.rand([20,7,5]),)
exported_model = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
###prepared_model.load_state_dict(torch.load(checkpoint_path))

quantized_model = convert_pt2e(prepared_model)

# move certain ops like dropout to eval mode, equivalent to `m.eval()`
torch.ao.quantization.move_exported_model_to_eval(quantized_model)

summary(quantized_model, input_data=torch.rand([20,7,5]))

but I get

NotImplementedError: Calling eval() is not supported yet.

The model is

class Simple1DCNN(torch.nn.Module):
    def __init__(self):
        super(Simple1DCNN, self).__init__()
        self.layer1 = torch.nn.Conv1d(in_channels=7, out_channels=20, kernel_size=5, stride=2)
        self.act1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Conv1d(in_channels=20, out_channels=500000, kernel_size=1)
    def forward(self, x):
        x = self.layer1(x)
        x = self.act1(x)
        x = self.layer2(x)
        return x
    
# Instantiate the model
model = Simple1DCNN()

One thing you can try is to set the qconfig for the MultiHeadAttentionLayer to None. You can do this with model.mult_head_attention_module.qconfig = None (note: mult_head_attention_module is a placeholder for the fully qualified name of the MultiHeadAttentionLayer module).
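If there are several of them, a sketch along these lines should also work; MultiHeadAttentionLayer here refers to your own class, so import it from wherever it is defined in your project:

for name, mod in quantized_model.named_modules():
    if isinstance(mod, MultiHeadAttentionLayer):
        # prepare/convert will then leave these modules in float
        mod.qconfig = None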

Did you call model.eval() in summary? If so, you can remove that call and replace it with torch.ao.quantization.move_exported_model_to_eval(quantized_model).

I tried retrieving the module names using

named_layers = dict(args.model.named_modules())

but it does not seem to work with the command you suggested

model.mult_head_attention_module.qconfig = None

No, in the code above I did actually use torch.ao.quantization.move_exported_model_to_eval(quantized_model), but the error somehow still seems to be linked to model.eval().

can you show the complete code?

The error message comes from calling .eval() on an exported model: pytorch/torch/ao/quantization/pt2e/utils.py at main · pytorch/pytorch · GitHub. So there must be somewhere that this is being called. Can you print the stack trace? Where does the error happen?

I believe I managed to set the qconfig of the MultiHeadAttentionLayer modules to None:

backend = "qnnpack" # good for mobile cpus #fbgemm qnnpack x86

quantized_model.qconfig = torch.ao.quantization.get_default_qconfig(backend)

for _, mod in quantized_model.named_modules():

    if isinstance(mod, torch.nn.Embedding):
        mod.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

# exclude MultiHeadAttentionLayer from quantization
for i in range(3):
    quantized_model.model.encoder_spec2midi.layers_freq[i].self_attention.qconfig = None
    quantized_model.model.decoder_spec2midi.layers_time[i].self_attention.qconfig = None

for i in range(2):
    quantized_model.model.decoder_spec2midi.layers_freq[i].self_attention.qconfig = None

quantized_model.model.decoder_spec2midi.layer_zero_freq.encoder_attention.qconfig = None
quantized_model.model.decoder_spec2midi.layers_freq[i].encoder_attention.qconfig = None

but I got the same error as above

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Quantize: 1, Conv2d: 3, Linear: 3]

However, by adding

quantized_model.model.encoder_spec2midi.tok_embedding_freq.qconfig = None
quantized_model.model.encoder_spec2midi.conv.qconfig = None
quantized_model.quant.qconfig = None

I now get

RuntimeError: promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: QUInt8 Float

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [QuantStub: 1, Conv2d: 3, Linear: 3, Embedding: 3, Dropout: 3, MultiHeadAttentionLayer: 5, Linear: 6, Linear: 6, Linear: 6, Dropout: 6, Linear: 6, Dropout: 5]

I think you’ll have a better chance to quantize this with XNNPACKQuantizer