Quantized model is slow and CPU usage becomes high with QNNPACK

Hi all

Recently, I have been trying to deploy my model to Android. I built libtorch for Android with NDK r19c. My model consists of Linear, ReLU and LayerNorm layers, with some cat ops in forward. I found that my model becomes slower after quantization, and the CPU usage also becomes high. Although my device is arm64 (aarch64), I built an arm32 libtorch.so so that it is compatible with arm32 devices. Why do you think my model becomes slower after quantization? Does the shape of the input and weights affect QNNPACK?

I thought QNNPACK should be efficient on ARM, so I don't know why this happens.
The following is my code:

```python
import torch
from torch.quantization import QuantStub, DeQuantStub


class Conv1dSubsampling(torch.nn.Module):

    def __init__(self, idim, odim, groups=1, use_bias=False, do_quant=False, svd_dim=-1):
        """Construct a Conv1dSubsampling object."""
        super(Conv1dSubsampling, self).__init__()
        self.relu = torch.nn.ReLU()
        self.norm = torch.nn.LayerNorm(odim)
        self.quant1 = QuantStub()
        self.quant2 = QuantStub()
        self.dequant1 = DeQuantStub()
        self.dequant2 = DeQuantStub()
        self.cnn_linear1 = torch.nn.Sequential(
            torch.nn.Linear(5 * idim, 256, bias=use_bias),
        )
        self.cnn_linear2 = torch.nn.Sequential(
            torch.nn.Linear(256 * 3, svd_dim, bias=use_bias),
            torch.nn.Linear(svd_dim, 256, bias=use_bias),
        )
        self.cnn_linear3 = torch.nn.Sequential(
            torch.nn.Linear(256, svd_dim, bias=use_bias),
            torch.nn.Linear(svd_dim, odim, bias=use_bias),
        )

    def forward(self, x):
        # slice and cat the input: stack 5 neighbouring frames
        x0 = x[:-4, :]
        x1 = x[1:-3, :]
        x2 = x[2:-2, :]
        x3 = x[3:-1, :]
        x4 = x[4:, :]
        x = torch.cat((x0, x1, x2, x3, x4), dim=-1)
        x = x[::2, :]
        # zero-pad to a fixed 16 frames
        x = torch.cat((x, torch.zeros(16 - x.shape[0], x.shape[1])))
        x = self.quant1(x)
        x = self.cnn_linear1(x)  # (t//2, 256)
        x = self.dequant1(x)
        # stack 3 neighbouring frames
        x0 = x[0:-2, :]
        x1 = x[1:-1, :]
        x2 = x[2:, :]
        x = torch.cat((x0, x1, x2), dim=-1)
        x = x[::2, :]
        # x = x[0:4, :]
        # zero-pad to a fixed 8 frames
        x = torch.cat((x, torch.zeros(8 - x.shape[0], x.shape[1])))
        x = self.quant2(x)
        x = self.cnn_linear2(x)  # (1, t//2, 256)
        x = self.cnn_linear3(x)
        x = self.norm(x)
        x = self.dequant2(x)
        return x[0:4, :]


idim = 40
batch = 20  # hypothetical frame count (not given in the post); the 16-frame padding fits at most 36
do_quant = True
# svd_dim keeps its default of -1 here, which nn.Linear rejects; a positive value must be passed
model = Conv1dSubsampling(idim, 400, do_quant=do_quant)

if do_quant:
    qconfig = torch.quantization.get_default_qconfig('qnnpack')
    model.eval()  # post-training static quantization expects eval mode
    model.qconfig = qconfig
    torch.backends.quantized.engine = 'qnnpack'
    torch.quantization.prepare(model, inplace=True)

    # calibrate the observers with random inputs
    for i in range(1, 100):
        x = model(torch.randn(batch, idim))
    torch.quantization.convert(model, inplace=True)

traced_module = torch.jit.trace(model, torch.randn(batch, idim))
script_module = torch.jit.script(model)
```

Finally, I load it using C++ and run forward.
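Before moving to C++, it can help to check in Python that the scripted quantized model saves and reloads cleanly. A minimal sketch, assuming a reduced stand-in model (`TinyQuantModel` is hypothetical, not the model above) and an in-memory buffer in place of a real file; on the device, the saved file would then be loaded with `torch::jit::load`.

```python
import io

import torch
from torch.quantization import QuantStub, DeQuantStub

# Use qnnpack when this build has it, otherwise fall back to fbgemm.
engine = "qnnpack" if "qnnpack" in torch.backends.quantized.supported_engines else "fbgemm"
torch.backends.quantized.engine = engine


class TinyQuantModel(torch.nn.Module):
    """Hypothetical stand-in: QuantStub -> Linear -> DeQuantStub."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc = torch.nn.Linear(200, 256)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


model = TinyQuantModel().eval()
model.qconfig = torch.quantization.get_default_qconfig(engine)
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for _ in range(20):  # calibrate observers with random data
        model(torch.randn(16, 200))
torch.quantization.convert(model, inplace=True)

scripted = torch.jit.script(model)

# In practice this would be scripted.save("model.pt"); a buffer keeps the sketch self-contained.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

x = torch.randn(16, 200)
max_diff = (scripted(x) - reloaded(x)).abs().max().item()
print(f"engine={engine} max abs diff after reload: {max_diff}")
```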

Hi @Sining_Sun, quantized LayerNorm currently has an efficient kernel in fbgemm (x86), but it does not have an efficient kernel in qnnpack (ARM). So, you are likely seeing the slow fallback path of the kernel on ARM.
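One way to confirm which kernels end up in the converted model is to print the class each submodule was swapped to. A minimal sketch, assuming a hypothetical `LinearNorm` module that mimics the Linear -> LayerNorm tail of the model above: after `convert`, `norm` should show up as the quantized LayerNorm module, which is the op described here as taking the slow path under qnnpack.

```python
import torch
from torch.quantization import QuantStub, DeQuantStub

# Use qnnpack when this build has it, otherwise fall back to fbgemm.
engine = "qnnpack" if "qnnpack" in torch.backends.quantized.supported_engines else "fbgemm"
torch.backends.quantized.engine = engine


class LinearNorm(torch.nn.Module):
    """Hypothetical reduced model: QuantStub -> Linear -> LayerNorm -> DeQuantStub."""

    def __init__(self, idim=256, odim=400):
        super().__init__()
        self.quant = QuantStub()
        self.fc = torch.nn.Linear(idim, odim)
        self.norm = torch.nn.LayerNorm(odim)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.norm(self.fc(self.quant(x))))


model = LinearNorm().eval()
model.qconfig = torch.quantization.get_default_qconfig(engine)
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for _ in range(20):  # calibrate observers with random data
        model(torch.randn(8, 256))
torch.quantization.convert(model, inplace=True)

# Print the module class each child was swapped to by convert().
for name, child in model.named_children():
    print(f"{name:8s} {type(child).__module__}.{type(child).__name__}")
```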

A workaround for now could be to let LayerNorm stay in fp32. You can do this by setting the qconfig to None for the LayerNorm module and moving the dequant so that it runs before LayerNorm.
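The workaround can be sketched as follows, assuming the same hypothetical Linear -> LayerNorm layout (`LinearNormFp32` is illustrative, not the model above): the model-level qconfig is set as usual, `norm.qconfig` is overridden with `None` so `prepare`/`convert` leave it alone, and the `DeQuantStub` is called before `self.norm` so LayerNorm receives an fp32 tensor.

```python
import torch
from torch.quantization import QuantStub, DeQuantStub

# Use qnnpack when this build has it, otherwise fall back to fbgemm.
engine = "qnnpack" if "qnnpack" in torch.backends.quantized.supported_engines else "fbgemm"
torch.backends.quantized.engine = engine


class LinearNormFp32(torch.nn.Module):
    """Hypothetical model: quantized Linear, then LayerNorm kept in fp32."""

    def __init__(self, idim=256, odim=400):
        super().__init__()
        self.quant = QuantStub()
        self.fc = torch.nn.Linear(idim, odim)
        self.dequant = DeQuantStub()
        self.norm = torch.nn.LayerNorm(odim)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc(x)       # runs as a quantized Linear after convert()
        x = self.dequant(x)  # back to fp32 *before* LayerNorm
        return self.norm(x)  # plain fp32 LayerNorm


model = LinearNormFp32().eval()
model.qconfig = torch.quantization.get_default_qconfig(engine)
model.norm.qconfig = None  # exclude LayerNorm from quantization
torch.quantization.prepare(model, inplace=True)
with torch.no_grad():
    for _ in range(20):  # calibrate observers with random data
        model(torch.randn(8, 256))
torch.quantization.convert(model, inplace=True)

print(type(model.fc), type(model.norm))
```

If `norm.qconfig` were left in place while the dequant moves ahead of it, `convert` would still swap in the quantized LayerNorm, which would then receive a float tensor; overriding the qconfig avoids that mismatch.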


Thanks. Apart from the LayerNorm problem, I found another one. Even when my network is very simple, for example just one Linear layer without LayerNorm, the CPU usage is very high after quantization. More details can be found in this post.

This problem has been confusing me for a long time.
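To isolate the single-Linear case, a rough latency comparison of an fp32 Linear against its quantized counterpart could look like the sketch below. It assumes the script runs on the target device (desktop timings say little about ARM), uses arbitrary layer sizes, and pins PyTorch to one intra-op thread with `torch.set_num_threads(1)` on the assumption that the thread pool contributes to the high CPU usage reading; the C++ counterpart would be `at::set_num_threads`. `SingleLinear` and `mean_latency_ms` are hypothetical names.

```python
import copy
import time

import torch
from torch.quantization import QuantStub, DeQuantStub

# Assumption: one intra-op thread makes the fp32/int8 comparison independent of the thread pool.
torch.set_num_threads(1)

# Use qnnpack when this build has it, otherwise fall back to fbgemm.
engine = "qnnpack" if "qnnpack" in torch.backends.quantized.supported_engines else "fbgemm"
torch.backends.quantized.engine = engine


class SingleLinear(torch.nn.Module):
    """Hypothetical one-layer model; the stubs are identity ops until convert()."""

    def __init__(self, in_features=200, out_features=256):
        super().__init__()
        self.quant = QuantStub()
        self.fc = torch.nn.Linear(in_features, out_features)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


def mean_latency_ms(model, x, iters=200, warmup=20):
    """Average wall-clock time per forward call, in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        return (time.perf_counter() - start) * 1000.0 / iters


x = torch.randn(16, 200)
fp32_model = SingleLinear().eval()

int8_model = copy.deepcopy(fp32_model)
int8_model.qconfig = torch.quantization.get_default_qconfig(engine)
torch.quantization.prepare(int8_model, inplace=True)
with torch.no_grad():
    for _ in range(20):  # calibrate observers with random data
        int8_model(torch.randn(16, 200))
torch.quantization.convert(int8_model, inplace=True)

fp32_ms = mean_latency_ms(fp32_model, x)
int8_ms = mean_latency_ms(int8_model, x)
print(f"fp32: {fp32_ms:.3f} ms   int8 ({engine}): {int8_ms:.3f} ms")
```

Running the same comparison with different thread counts and input shapes would show whether the slowdown follows the quantized Linear itself or the threading setup.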