Hi! My quantized segmentation model is slower than the float model on Android, as well as on desktop CPU, and I can't figure out why. I tried different ways to quantize it and different layers. A full, self-contained example with QNNPACK is attached at the end of this post.
On Android I use the following dependencies (I also tried other versions here):
implementation 'org.pytorch:pytorch_android:1.5.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.5.0-SNAPSHOT'
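The app loads the model as a TorchScript file; the export step is roughly this (a minimal sketch that traces the already-converted model; the input shape and file name are placeholders, not my exact pipeline):

# Sketch: export the quantized (converted) model for the Android app.
example = torch.rand(1, 3, 224, 224)  # placeholder input shape
traced = torch.jit.trace(model, example)  # model: the converted model from the script below
traced.save('segmentation_qnnpack.pt')  # placeholder file name

And here is the full example: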
import torch
import os
from torch import nn
from torchvision.models.resnet import BasicBlock  # ResNet is redefined below, so only BasicBlock is imported
from torch.quantization import fuse_modules
from torch.nn import functional as F
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class QuantizableBasicBlock(BasicBlock):
def __init__(self, *args, **kwargs):
super(QuantizableBasicBlock, self).__init__(*args, **kwargs)
self.add_relu = torch.nn.quantized.FloatFunctional()
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.add_relu.add_relu(out, identity)
return out
def fuse_model(self):
torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
['conv2', 'bn2']], inplace=True)
        if self.downsample is not None:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
shapes = (16, 32, 64, 128)
        block = QuantizableBasicBlock  # always use the quantizable block, ignoring the argument
self.inplanes = shapes[0]
self.dilation = 1
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False, dilation=1)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, shapes[0], layers[0])
self.layer2 = self._make_layer(block, shapes[1], layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, shapes[2], layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, shapes[3], layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(shapes[3] * block.expansion, num_classes)  # 128 channels here, not 512: this net is slimmer than stock ResNet
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def _replace_relu(module):
reassign = {}
for name, mod in module.named_children():
_replace_relu(mod)
        if isinstance(mod, (nn.ReLU, nn.ReLU6)):
            reassign[name] = nn.ReLU(inplace=False)
for key, value in reassign.items():
module._modules[key] = value
class QuantizableResNet(ResNet):
def __init__(self, *args, **kwargs):
super(QuantizableResNet, self).__init__(*args, **kwargs)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self._forward_impl(x)
x = self.dequant(x)
return x
def fuse_model(self):
fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
for m in self.modules():
if isinstance(m, (QuantizableBasicBlock, DecoderBlock)):
m.fuse_model()
class Conv2dReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding=0,
stride=1, use_batchnorm=True, **batchnorm_params):
super().__init__()
layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, bias=not use_batchnorm, groups=1),
nn.ReLU(),
]
if use_batchnorm:
layers.insert(1, nn.BatchNorm2d(out_channels, **batchnorm_params))
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
def fuse_model(self):
fuse_modules(self, ['block.0', 'block.1', 'block.2'], inplace=True)
class Up(nn.Module):
def __init__(self):
super().__init__()
    def forward(self, x):
        # F.upsample_bilinear is deprecated; this is the equivalent interpolate call
        return F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.block = nn.Sequential(
Conv2dReLU(in_channels, in_channels // 4, kernel_size=1, use_batchnorm=True),
Up(),
Conv2dReLU(in_channels // 4, out_channels, kernel_size=1, use_batchnorm=True),
)
def forward(self, x):
return self.block(x)
def fuse_model(self):
for m in self.modules():
if isinstance(m, Conv2dReLU):
m.fuse_model()
class ResNetUnet(QuantizableResNet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pretrained = False
in_channels = (128, 64, 32, 16, 16)
prefinal_channels = 16
final_channels = 3
self.block1 = DecoderBlock(in_channels[0], in_channels[1])
self.block2 = DecoderBlock(in_channels[1], in_channels[2])
self.block3 = DecoderBlock(in_channels[2], in_channels[3])
self.block4 = DecoderBlock(in_channels[3], in_channels[4])
self.block5 = DecoderBlock(in_channels[4], prefinal_channels)
self.final_conv = nn.Conv2d(prefinal_channels, final_channels, kernel_size=(1, 1))
self.linear = nn.Linear(128, 10)
self.pool = nn.AdaptiveAvgPool2d(1)
self.add = torch.nn.quantized.FloatFunctional()
del self.fc
def forward(self, x):
x = self.quant(x)
x, cls = self._forward_impl(x)
x = self.dequant(x)
cls = self.dequant(cls)
return x, cls
def _forward_impl(self, x):
x0 = self.conv1(x)
x0 = self.bn1(x0)
x0 = self.relu(x0)
x1 = self.maxpool(x0)
x1 = self.layer1(x1)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
cls_out = self.linear(self.pool(x4).view(x4.size(0), -1))
x4 = self.block1(x4)
x3 = self.add.add(x3, x4)
x3 = self.block2(x3)
x2 = self.add.add(x2, x3)
x2 = self.block3(x2)
x1 = self.add.add(x1, x2)
x1 = self.block4(x1)
x0 = self.add.add(x0, x1)
x0 = self.block5(x0)
x0 = self.final_conv(x0)
return x0, cls_out
def print_size_of_model(model):
torch.save(model.state_dict(), "temp.p")
print('Size (MB):', os.path.getsize("temp.p") / 1e6)
os.remove('temp.p')
def get_model():
model = ResNetUnet(QuantizableBasicBlock, [2, 2, 2, 2])
_replace_relu(model)
return model
def quantize_model(model, backend):
_dummy_input_data = torch.rand(1, 3, 224, 224)
if backend not in torch.backends.quantized.supported_engines:
raise RuntimeError("Quantized backend not supported ")
torch.backends.quantized.engine = backend
model.eval()
if backend == 'fbgemm':
model.qconfig = torch.quantization.QConfig(
activation=torch.quantization.default_observer,
weight=torch.quantization.default_per_channel_weight_observer)
elif backend == 'qnnpack':
model.qconfig = torch.quantization.QConfig(
activation=torch.quantization.default_observer,
weight=torch.quantization.default_weight_observer)
model.fuse_model()
torch.quantization.prepare(model, inplace=True)
model(_dummy_input_data)
torch.quantization.convert(model, inplace=True)
def evaltime(model):
    from time import time
    res = 0
    n = 500
    x = torch.ones(1, 3, 224, 224)  # create the input once, outside the timed region
    for _ in range(n):
        t = time()
        with torch.no_grad():
            model(x)
        res += time() - t
    return res / n
torch.set_num_threads(1)
model = get_model()
model.eval()
print('time/image, initial model', evaltime(model)) # 0.025
print_size_of_model(model) # Size (MB): 2.89155
quantize_model(model, 'qnnpack')
print('time/image, quantized model', evaltime(model)) # 0.101
print_size_of_model(model) # Size (MB): 0.740964
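To narrow down where the time goes, a per-op profile of one forward pass can be printed like this (a minimal sketch using torch.autograd.profiler, run after quantize_model); it should show whether the quantized convolutions themselves or the quantize/dequantize and upsampling glue dominate:

# Sketch: profile one forward pass of the quantized model to find the slow ops.
with torch.autograd.profiler.profile() as prof:
    with torch.no_grad():
        model(torch.ones(1, 3, 224, 224))
print(prof.key_averages().table(sort_by='self_cpu_time_total'))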