`torch.mean(keepdim=True)` with `quint8` tensor does not work with qnnpack backend

I’m new to quantization, so I couldn’t figure out an easy way to reproduce this without going through the whole flow.

From pdb during a forward pass of a quantized model:

print(x.dtype)
# >> torch.quint8
print(x.shape)
# >> torch.Size([1, 40, 64, 384])
print(x.mean((2,3), keepdim=True).shape)
# >> torch.Size([1, 40])
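# expected torch.Size([1, 40, 1, 1]); keepdim=True appears to be ignored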

This happens when I run the forward pass just after setting torch.backends.quantized.engine = 'qnnpack'.

If I do not set it, the forward pass runs fine and is 10x faster than the non-quantized version of my model (in other words, as expected).

Running this on Android causes the same issue.
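In case it helps anyone else hitting this, here is a possible workaround. It is only a sketch and not verified: it assumes the reduction computes the right values and only the keepdim flag is being ignored, which is what the shapes above suggest.

# hypothetical workaround: drop keepdim and restore the reduced
# dimensions by hand (reshape is supported on quantized tensors)
y = x.mean((2, 3))                            # shape (1, 40) under qnnpack
y = y.reshape(y.shape[0], y.shape[1], 1, 1)   # back to (1, 40, 1, 1)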


I’m unable to reproduce:

import torch


# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = x.mean((2, 3), keepdim=True)
        x = self.dequant(x)
        return x


model_fp32 = M()
model_fp32.train()
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
# fuse conv + bn + relu so they are quantized as a single module
model_fp32_fused = torch.quantization.fuse_modules(model_fp32,
                                                   [['conv', 'bn', 'relu']])
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)
# run one batch through the prepared model so the observers collect statistics
x = torch.rand(2, 1, 224, 224)
model_fp32_prepared.eval()
model_fp32_prepared(x)
model_int8 = torch.quantization.convert(model_fp32_prepared)

y = model_int8(x)
print(y)
print(y.shape)
# >> tensor([[[[0.]]],
# >>         [[[0.]]]])
# >> torch.Size([2, 1, 1, 1])

Do you have a better repro?

Hello Alexander,

I can confirm this and took the liberty to file QNNPACK mean with keepdim doesn't work · Issue #58668 · pytorch/pytorch · GitHub. Thank you so much for reporting this with very precise repro information! This makes things much easier.

For reference: This illustrates the problem:

torch.backends.quantized.engine = 'qnnpack'
x = torch.quantize_per_tensor(torch.randn(5, 5, 5, 5),
                              scale=0.2, zero_point=0, dtype=torch.quint8)
print(torch.backends.quantized.engine, x.mean((2, 3), keepdim=True).shape)

torch.backends.quantized.engine = 'fbgemm'
x = torch.quantize_per_tensor(torch.randn(5, 5, 5, 5),
                              scale=0.2, zero_point=0, dtype=torch.quint8)
print(torch.backends.quantized.engine, x.mean((2, 3), keepdim=True).shape)
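On the qnnpack run the kept dimensions are dropped (consistent with the torch.Size([1, 40]) result in the original report), while the fbgemm run returns the expected torch.Size([5, 5, 1, 1]).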

Best regards

Thomas


I’m happy to report that the issue linked above has been closed, so we should see nightlies with the problem fixed. I don’t think the fix made it into 1.9, but I hope to make Raspberry Pi wheels with the fix soon enough.

Best regards

Thomas
