import torch
# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        # global average pooling over the spatial dimensions
        x = x.mean((2, 3), keepdim=True)
        x = self.dequant(x)
        return x
model_fp32 = M()
# QAT setup has to happen with the model in training mode
model_fp32.train()
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
# fuse conv + bn + relu into a single module before preparing for QAT
model_fp32_fused = torch.quantization.fuse_modules(
    model_fp32, [['conv', 'bn', 'relu']])
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)
# a real QAT flow would train model_fp32_prepared here; for this toy
# example a single forward pass stands in for the training loop
x = torch.rand(2, 1, 224, 224)
model_fp32_prepared.eval()
model_fp32_prepared(x)
# convert the prepared model to an int8 quantized model
model_int8 = torch.quantization.convert(model_fp32_prepared)
y = model_int8(x)
print(y)
print(y.shape)
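
As a quick sanity check (a minimal sketch against the snippet above, not something the original post included), you can inspect the converted submodules: the fused conv/bn/relu should show up as a quantized module, and the output should be back in floating point thanks to the DeQuantStub:

# after convert, the fused conv/bn/relu is a single quantized module
print(model_int8.conv)  # prints something like QuantizedConvReLU2d(...)
# DeQuantStub puts the output back into floating point
print(y.dtype)  # torch.float32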
I’m happy to report that the issue linked above has been closed, so we should see nightlies with the problem fixed. I don’t think the fix made it into 1.9, but I hope to build Raspberry Pi wheels with it soon.