I have the following model, which I want to run on Android:
class depthwise_separable_conv(nn.Module):
    """Depthwise-separable 2-D convolution followed by ReLU.

    A depthwise convolution (``groups=nin``, ``kernels_per_layer`` filters per
    input channel) is followed by a 1x1 pointwise convolution that mixes the
    channels into ``nout`` outputs, then a ReLU.

    Args:
        nin: Number of input channels.
        nout: Number of output channels.
        kernel_size: Spatial kernel size of the depthwise convolution.
        kernels_per_layer: Depthwise filters per input channel (channel
            multiplier). Defaults to 1.
        padding: Zero-padding of the depthwise convolution. Defaults to 1,
            the value that used to be hard-coded. Note that with
            ``kernel_size=5`` this shrinks each spatial dimension by 2, which
            downstream layer sizes may depend on.
    """

    def __init__(self, nin, nout, kernel_size, kernels_per_layer=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(
            nin,
            nin * kernels_per_layer,
            kernel_size=kernel_size,
            padding=padding,
            groups=nin,
        )
        self.pointwise = nn.Conv2d(nin * kernels_per_layer, nout, kernel_size=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply the depthwise conv, the pointwise conv, then ReLU to ``x``."""
        return self.relu(self.pointwise(self.depthwise(x)))
class Net(nn.Module):
    """Small depthwise-separable CNN classifier with quantization stubs.

    Input is single-channel (``conv1`` has one input channel). The flattened
    feature size ``32 * 6 * 13`` fixes the accepted spatial size. Each
    ``depthwise_separable_conv(..., 5)`` (padding 1) shrinks H and W by 2, and
    each 2x2 average pool halves them with flooring. So H must be in 62..69 and
    W in 118..125, for example an input of shape (N, 1, 64, 120).

    Every quantize/dequantize boundary has its own stub. In eager-mode
    quantization a stub owns the observer (and, after conversion, the
    scale/zero_point) for exactly the tensor it wraps. Reusing one stub for the
    network input and for the LocalResponseNorm output would force both onto a
    single set of quantization parameters.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = depthwise_separable_conv(1, 6, 5)
        self.conv2 = depthwise_separable_conv(6, 16, 5)
        self.conv3 = depthwise_separable_conv(16, 32, 5)
        self.pool = nn.AvgPool2d(2, 2)
        self.lrn = nn.LocalResponseNorm(2)
        self.fc1 = nn.Linear(32 * 6 * 13, 250)
        self.relu1 = nn.ReLU(inplace=False)
        self.fc2 = nn.Linear(250, 84)
        self.relu2 = nn.ReLU(inplace=False)
        self.fc3 = nn.Linear(84, 2)
        self.soft = nn.Softmax(dim=1)
        # Outer boundary: quantize the network input, dequantize the logits.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # Separate boundary around the float-only LocalResponseNorm.
        self.dequant_lrn = DeQuantStub()
        self.quant_lrn = QuantStub()

    def forward(self, x):
        """Return per-class probabilities (softmax over dim 1) for batch ``x``."""
        x = self.quant(x)
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.pool(self.conv3(x))
        # LRN is run in float here, presumably because no quantized kernel
        # exists for it, so leave and re-enter the quantized domain around it.
        x = self.dequant_lrn(x)
        x = self.lrn(x)
        x = self.quant_lrn(x)
        # flatten(1) keeps the batch dimension. reshape(-1, 32 * 6 * 13) could
        # silently fold a wrong spatial size into the batch dimension, whereas
        # here a mismatch fails loudly at fc1.
        x = x.flatten(1)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return self.soft(x)
I created two versions, with and without quantization (the non-quantized model doesn't have the quant() and dequant() parts).
I performed quantization using the following code:
# Post-training static quantization with FX graph mode.
import copy

backend = "qnnpack"
qconfig = torch.quantization.get_default_qconfig(backend)
# prepare_fx takes its configuration from qconfig_dict below; the attribute
# is kept only for eager-mode APIs that read module.qconfig.
net.qconfig = qconfig
# Run quantized kernels with the same engine the qconfig was built for.
torch.backends.quantized.engine = backend
qconfig_dict = {"": qconfig}

# Quantize a deep copy so `net` remains an untouched float reference model
# (plain assignment only aliased it). Post-training static quantization is
# prepared in eval mode.
quant_net = copy.deepcopy(net).eval()
quant_net = prepare_fx(quant_net, qconfig_dict)
# Calibration only feeds data through the observers; no gradients are needed.
# NOTE(review): a single batch gives rough activation ranges; calibrating on
# several representative batches usually gives better scales.
with torch.no_grad():
    quant_net(torch.as_tensor(batch, dtype=torch.float32))  # calibrate
quant_net = convert_fx(quant_net)
and scripted both models using:
# Compile the converted model with TorchScript. Despite the "traced" variable
# names, this uses torch.jit.script, not torch.jit.trace.
traced_script_module = torch.jit.script(quant_net)
# Apply optimize_for_mobile to the scripted module before export.
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
# Save in lite-interpreter (.ptl) format. MODEL_DIR is joined by plain string
# concatenation, so it presumably must end with a path separator — verify.
traced_script_module_optimized._save_for_lite_interpreter(MODEL_DIR + "stQuant_lite.ptl")
In Python, I can see an inference-time reduction from quantization, but the reverse happens on Android. Moreover, on Android the quantized model also uses more RAM.
Also, the weirdest thing happens on Android: if I rename the scripted model "stQuant_lite.ptl" to "stQuant_lite_11.ptl", keeping everything else the same (I literally just use Refactor → Rename), I get the following error:
Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend.
I'm using torch 1.9.0 in Python on Linux and pytorch_lite:1.9.0 on Android. My app is almost identical to the HelloWorldApp, except that I feed an empty FloatBuffer to the model instead of an image.
Any help is greatly appreciated.