Quantization causing reduced performance on PyTorch Android

I have the following model, which I want to run on Android:

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout, kernel_size, kernels_per_layer=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin * kernels_per_layer, kernel_size=kernel_size, padding=1, groups=nin)
        self.pointwise = nn.Conv2d(nin * kernels_per_layer, nout, kernel_size=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        out = self.relu(out)
        return out
    
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = depthwise_separable_conv(1, 6, 5)
        self.conv2 = depthwise_separable_conv(6, 16, 5)
        self.conv3 = depthwise_separable_conv(16, 32, 5)
        self.pool = nn.AvgPool2d(2, 2)
        self.lrn = nn.LocalResponseNorm(2)
        self.fc1 = nn.Linear(32 * 6 * 13, 250)
        self.relu1 = nn.ReLU(inplace=False)
        self.fc2 = nn.Linear(250, 84)
        self.relu2 = nn.ReLU(inplace=False)
        self.fc3 = nn.Linear(84, 2)
        self.soft = nn.Softmax(dim=1)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.pool((self.conv1(x)))
        x = self.pool((self.conv2(x)))
        x = self.pool((self.conv3(x)))
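        x = self.dequant(x)  # LocalResponseNorm runs on float tensors
        x = self.lrn(x)
        x = self.quant(x)  # reuses the same QuantStub, so both quantize steps share one scale/zero point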
        x = x.reshape(-1, 32 * 6 * 13)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        x = self.soft(x)
        return x

I created two versions of the model, one with quantization and one without (the unquantized version omits the quant() and dequant() calls).

I performed quantization using the following code:

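import torch
from torch.quantization.quantize_fx import prepare_fx, convert_fx

backend = "qnnpack"
qconfig = torch.quantization.get_default_qconfig(backend)
net.qconfig = qconfig  # only used by eager mode; FX takes the qconfig_dict below
torch.backends.quantized.engine = backend

qconfig_dict = {"": qconfig}
quant_net = net.eval()  # post-training static quantization expects eval mode
quant_net = prepare_fx(quant_net, qconfig_dict)
quant_net(torch.Tensor(batch))  # calibrate on a representative batch
quant_net = convert_fx(quant_net)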

I then scripted both models with:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

traced_script_module = torch.jit.script(quant_net)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter(MODEL_DIR + "stQuant_lite.ptl")
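Because the failures only show up on the device and seem to depend on the file name, one way to narrow things down may be to load the saved .ptl file back in Python with the lite interpreter and compare its output with the scripted module. A minimal sketch, assuming a small float stand-in model and a hypothetical file name (model_lite.ptl); optimize_for_mobile is left out for brevity but could be applied before saving. Note that _load_for_lite_interpreter in torch.jit.mobile is underscore-prefixed, so it is semi-private API:

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.jit.mobile import _load_for_lite_interpreter

# Small float stand-in for the real model (assumption, not the model above).
model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU()).eval()
scripted = torch.jit.script(model)
x = torch.zeros(1, 1, 16, 16)  # zero input, like the app's FloatBuffer

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_lite.ptl")  # hypothetical file name
    scripted._save_for_lite_interpreter(path)
    lite_module = _load_for_lite_interpreter(path)  # same file format the Android app loads
    ref = scripted(x)
    got = lite_module(x)

print(got.shape, torch.allclose(ref, got))
```

For the real case, the path would point at the exact file copied into the app's assets, so any difference between the Python and Android results can be attributed to the device side rather than to the file itself.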

In Python, quantization reduces inference time as expected. On Android, however, the opposite happens: the quantized model is slower, and it also uses more RAM than the unquantized one.

The weirdest part is this: if I rename the scripted model from “stQuant_lite.ptl” to “stQuant_lite_11.ptl”, keeping everything else the same (I literally just use Refactor -> Rename), I get the following error:

Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend.

Similarly, with a model named ‘lite.ptl’ I even got the wrong output shape, but after renaming it to “_lite.plt” it gave the expected output.

I’m using torch 1.9.0 in Python on Linux and pytorch_lite:1.9.0 on Android. My app is essentially the HelloWorldApp, except that I feed an empty (zero-initialized) FloatBuffer to the model instead of an image.

Any help is greatly appreciated.

Regarding the filename differences: that sounds like spooky action at a distance. Super weird. In general, though, the error “Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend.” means that your quantized::conv2d op is receiving a float tensor as input.
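To illustrate that failure mode: a quantized conv module raises this error when it receives a plain float tensor, and runs once the input has been quantized. A minimal sketch using torch.nn.quantized.Conv2d with its default placeholder weights, the default quantized engine, and arbitrary quantization parameters (scale=0.05, zero_point=128); the exact exception class and message may differ between PyTorch versions:

```python
import torch
import torch.nn.quantized as nnq

# Quantized conv with default placeholder weights; uses the default quantized engine.
conv = nnq.Conv2d(1, 4, kernel_size=3)
x = torch.randn(1, 1, 8, 8)

float_input_failed = False
try:
    conv(x)  # float tensor: quantized::conv2d has no kernel for the plain 'CPU' backend
except RuntimeError as err:  # newer versions raise NotImplementedError, a RuntimeError subclass
    float_input_failed = True
    print("float input:", str(err).splitlines()[0])

# Quantizing the input first (arbitrary scale/zero_point) selects the QuantizedCPU kernel.
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)
out = conv(qx)
print(out.dtype)  # torch.quint8
```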

Regarding runtime on Android: are you comparing the fp32 model’s runtime on Android with the quantized model’s runtime on Android?

Regarding memory footprint: are you talking about peak memory or average memory utilization? And how large is the difference compared to the fp32 model?

Yeah, I saw the same comment in multiple places, but the model does run on uint8 tensors; I’ve verified that in Python.

Yes, I’m comparing the lite versions of both models.

I run a loop of 100 iterations to get a better measurement of the time per inference, so peak and average memory are almost the same (I guess the garbage collector can’t run fast enough between iterations). I’m seeing around 100 MB more memory usage for the quantized model.
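When comparing fp32 and quantized runs, a timing loop that discards a few warm-up iterations (the first calls can include one-time costs) and reports the median may be less noisy than a plain average. A minimal Python sketch; benchmark is a hypothetical helper, and the small Sequential model and 32x32 zero input are stand-ins for the real models and the zero-filled FloatBuffer:

```python
import statistics
import time

import torch
import torch.nn as nn


def benchmark(model, example, iters=100, warmup=10):
    """Hypothetical helper: per-iteration latencies in milliseconds."""
    timings = []
    with torch.no_grad():
        for _ in range(warmup):  # warm-up runs are not recorded
            model(example)
        for _ in range(iters):
            start = time.perf_counter()
            model(example)
            timings.append((time.perf_counter() - start) * 1000.0)
    return timings


# Stand-ins for the real model and for the zero-filled FloatBuffer used in the app.
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU()).eval()
example = torch.zeros(1, 1, 32, 32)

timings = benchmark(model, example)
print(f"median {statistics.median(timings):.3f} ms, max {max(timings):.3f} ms")
```

Running the same helper on the fp32 and quantized models, with the same input and thread settings, keeps the comparison on equal terms.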

Could you run print(model.graph) in a Python script or shell and paste the output here?

graph(%self : __torch__.___torch_mangle_2997.Net,
      %x.1 : Tensor):
  %2 : int[] = prim::Constant[value=[1, 1, 1]]()
  %3 : int[] = prim::Constant[value=[2, 1, 1]]()
  %4 : int[] = prim::Constant[value=[0, 0, 0, 0, 1, 0]]()
  %5 : int[] = prim::Constant[value=[1, 1]]()
  %6 : int[] = prim::Constant[value=[2, 1]]()
  %7 : int[] = prim::Constant[value=[0, 0, 1, 0]]()
  %8 : float = prim::Constant[value=1.]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/modules/normalization.py:56:37
  %9 : float = prim::Constant[value=0.75]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/modules/normalization.py:55:67
  %10 : float = prim::Constant[value=0.0001]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/modules/normalization.py:55:55
  %11 : str = prim::Constant[value="Expected 3D or higher dimensionality input (got {} dimensions)"]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2374:51
  %12 : int[] = prim::Constant[value=[0, 0, 0]]()
  %13 : int[] = prim::Constant[value=[0, 0]]()
  %14 : int[] = prim::Constant[value=[2, 2]]()
  %15 : str = prim::Constant[value="AssertionError: Padding length too large"]()
  %16 : int = prim::Constant[value=3]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4156:26
  %17 : NoneType = prim::Constant()
  %18 : bool = prim::Constant[value=0]()
  %19 : bool = prim::Constant[value=1]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:28:4
  %20 : int = prim::Constant[value=2]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:29:23
  %21 : int = prim::Constant[value=1]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:30:70
  %22 : int = prim::Constant[value=4]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:31
  %23 : str = prim::Constant[value="Input shape must be (N, C, H, W)!"]() # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:421:29
  %24 : float = prim::Constant[value=0.]()
  %25 : float = prim::Constant[value=0.075978539884090424]()
  %26 : int = prim::Constant[value=-1]() # :40:30
  %self.fc3.zero_point : int = prim::Constant[value=81]()
  %self.fc3.scale : float = prim::Constant[value=0.092987142503261566]()
  %self.fc3._packed_params._packed_params : __torch__.torch.classes.quantized.LinearPackedParamsBase = prim::Constant[value=object(0x4ccd2080)]()
  %self.fc2.scale : float = prim::Constant[value=0.053463395684957504]()
  %self.fc2._packed_params._packed_params : __torch__.torch.classes.quantized.LinearPackedParamsBase = prim::Constant[value=object(0x3d5bff60)]()
  %self.fc1.scale : float = prim::Constant[value=0.14854238927364349]()
  %self.fc1._packed_params._packed_params : __torch__.torch.classes.quantized.LinearPackedParamsBase = prim::Constant[value=object(0x4bf6a5b0)]()
  %self.conv3.pointwise.scale : float = prim::Constant[value=0.10303349047899246]()
  %self.conv3.pointwise._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant[value=object(0x3d5c1230)]()
  %self.conv3.depthwise.zero_point : int = prim::Constant[value=168]()
  %self.conv3.depthwise.scale : float = prim::Constant[value=0.48179984092712402]()
  %self.conv3.depthwise._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant[value=object(0x3d5bcb60)]()
  %self.conv2.pointwise.scale : float = prim::Constant[value=0.40423935651779175]()
  %self.conv2.pointwise._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant[value=object(0x3d5c5a00)]()
  %self.conv2.depthwise.zero_point : int = prim::Constant[value=205]()
  %self.conv2.depthwise.scale : float = prim::Constant[value=0.69283735752105713]()
  %self.conv2.depthwise._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant[value=object(0x3d5bd8e0)]()
  %self.conv1.pointwise.scale : float = prim::Constant[value=0.12927700579166412]()
  %self.conv1.pointwise._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant[value=object(0x3d66a660)]()
  %self.conv1.depthwise.zero_point : int = prim::Constant[value=0]()
  %self.conv1.depthwise.scale : float = prim::Constant[value=0.17992280423641205]()
  %self.conv1.depthwise._packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant[value=object(0x3d70a6d0)]()
  %self.quant.dtype : int = prim::Constant[value=13]()
  %x.5 : Tensor = aten::quantize_per_tensor(%x.1, %25, %self.conv1.depthwise.zero_point, %self.quant.dtype) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:52:15
  %51 : int[] = aten::size(%x.5) # :7:9
  %52 : int = aten::len(%51) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:11
  %53 : bool = aten::ne(%52, %22) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:11
   = prim::If(%53) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:8
    block0():
       = prim::RaiseException(%23) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:421:12
      -> ()
    block1():
      -> ()
  %out.17 : Tensor = quantized::conv2d(%x.5, %self.conv1.depthwise._packed_params, %self.conv1.depthwise.scale, %self.conv1.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:426:15
  %55 : int[] = aten::size(%out.17) # :7:9
  %56 : int = aten::len(%55) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:11
  %57 : bool = aten::ne(%56, %22) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:11
   = prim::If(%57) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:8
    block0():
       = prim::RaiseException(%23) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:80:12
      -> ()
    block1():
      -> ()
  %out.25 : Tensor = quantized::conv2d_relu(%out.17, %self.conv1.pointwise._packed_params, %self.conv1.pointwise.scale, %self.conv1.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:85:15
  %x.9 : Tensor = aten::avg_pool2d(%out.25, %14, %14, %13, %18, %19, %17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/modules/pooling.py:615:15
  %60 : int[] = aten::size(%x.9) # :7:9
  %61 : int = aten::len(%60) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:11
  %62 : bool = aten::ne(%61, %22) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:11
   = prim::If(%62) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:8
    block0():
       = prim::RaiseException(%23) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:421:12
      -> ()
    block1():
      -> ()
  %out.33 : Tensor = quantized::conv2d(%x.9, %self.conv2.depthwise._packed_params, %self.conv2.depthwise.scale, %self.conv2.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:426:15
  %64 : int[] = aten::size(%out.33) # :7:9
  %65 : int = aten::len(%64) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:11
  %66 : bool = aten::ne(%65, %22) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:11
   = prim::If(%66) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:8
    block0():
       = prim::RaiseException(%23) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:80:12
      -> ()
    block1():
      -> ()
  %out.41 : Tensor = quantized::conv2d_relu(%out.33, %self.conv2.pointwise._packed_params, %self.conv2.pointwise.scale, %self.conv1.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:85:15
  %x.13 : Tensor = aten::avg_pool2d(%out.41, %14, %14, %13, %18, %19, %17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/modules/pooling.py:615:15
  %69 : int[] = aten::size(%x.13) # :7:9
  %70 : int = aten::len(%69) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:11
  %71 : bool = aten::ne(%70, %22) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:11
   = prim::If(%71) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:420:8
    block0():
       = prim::RaiseException(%23) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:421:12
      -> ()
    block1():
      -> ()
  %out.3 : Tensor = quantized::conv2d(%x.13, %self.conv3.depthwise._packed_params, %self.conv3.depthwise.scale, %self.conv3.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py:426:15
  %73 : int[] = aten::size(%out.3) # :7:9
  %74 : int = aten::len(%73) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:11
  %75 : bool = aten::ne(%74, %22) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:11
   = prim::If(%75) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:79:8
    block0():
       = prim::RaiseException(%23) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:80:12
      -> ()
    block1():
      -> ()
  %out.9 : Tensor = quantized::conv2d_relu(%out.3, %self.conv3.pointwise._packed_params, %self.conv3.pointwise.scale, %self.conv1.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py:85:15
  %x.17 : Tensor = aten::avg_pool2d(%out.9, %14, %14, %13, %18, %19, %17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/modules/pooling.py:615:15
  %x.21 : Tensor = aten::dequantize(%x.17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:84:15
  %dim.1 : int = aten::dim(%x.21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2371:10
  %80 : bool = aten::lt(%dim.1, %16) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2372:7
   = prim::If(%80) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2372:4
    block0():
      %81 : str = aten::format(%11, %dim.1) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2374:51
       = prim::RaiseException(%81) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2373:8
      -> ()
    block1():
      -> ()
  %82 : Tensor = aten::mul(%x.21, %x.21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2379:10
  %div.1 : Tensor = aten::unsqueeze(%82, %21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2379:10
  %84 : bool = aten::eq(%dim.1, %16) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2380:7
  %div : Tensor = prim::If(%84) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2380:4
    block0():
      %86 : int = aten::dim(%div.1) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:28
      %87 : bool = aten::le(%20, %86) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:11
       = prim::If(%87) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:4
        block0():
          -> ()
        block1():
           = prim::RaiseException(%15) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:4
          -> ()
      %88 : Tensor = aten::constant_pad_nd(%div.1, %7, %24) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4153:15
      %89 : Tensor = aten::avg_pool2d(%88, %6, %5, %13, %18, %19, %17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2382:14
      %div.11 : Tensor = aten::squeeze(%89, %21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2382:14
      -> (%div.11)
    block1():
      %sizes.1 : int[] = aten::size(%x.21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2384:16
      %92 : int = aten::getitem(%sizes.1, %self.conv1.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2385:23
      %93 : int = aten::getitem(%sizes.1, %21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2385:36
      %94 : int = aten::getitem(%sizes.1, %20) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2385:46
      %95 : int[] = prim::ListConstruct(%92, %21, %93, %94, %26)
      %div.17 : Tensor = aten::view(%div.1, %95) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2385:14
      %97 : int = aten::dim(%div.17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:28
      %98 : bool = aten::le(%16, %97) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:11
       = prim::If(%98) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:4
        block0():
          -> ()
        block1():
           = prim::RaiseException(%15) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4151:4
          -> ()
      %99 : Tensor = aten::constant_pad_nd(%div.17, %4, %24) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:4153:15
      %100 : Tensor = aten::avg_pool3d(%99, %3, %2, %12, %18, %19, %17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2387:14
      %div.29 : Tensor = aten::squeeze(%100, %21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2387:14
      %div.35 : Tensor = aten::view(%div.29, %sizes.1) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2388:14
      -> (%div.35)
  %103 : Tensor = aten::mul(%div, %10) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2389:10
  %104 : Tensor = aten::add(%103, %8, %21) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2389:10
  %div.49 : Tensor = aten::pow(%104, %9) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2389:10
  %x.25 : Tensor = aten::div(%x.21, %div.49) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:2390:11
  %x.29 : Tensor = aten::quantize_per_tensor(%x.25, %25, %self.conv1.depthwise.zero_point, %self.quant.dtype) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:52:15
  %108 : int[] = aten::size(%x.29) # :7:9
  %109 : int = aten::getitem(%108, %self.conv1.depthwise.zero_point) # :40:19
  %110 : int[] = prim::ListConstruct(%109, %26)
  %x.35 : Tensor = aten::view(%x.29, %110) # :40:12
  %x.39 : Tensor = quantized::linear_relu(%x.35, %self.fc1._packed_params._packed_params, %self.fc1.scale, %self.conv1.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py:28:15
  %x.43 : Tensor = quantized::linear_relu(%x.39, %self.fc2._packed_params._packed_params, %self.fc2.scale, %self.conv1.depthwise.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py:28:15
  %x.47 : Tensor = quantized::linear(%x.43, %self.fc3._packed_params._packed_params, %self.fc3.scale, %self.fc3.zero_point) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/linear.py:168:15
  %x.51 : Tensor = aten::dequantize(%x.47) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:84:15
  %x.55 : Tensor = aten::softmax(%x.51, %21, %17) # /home/lakshya/.local/lib/python3.7/site-packages/torch/nn/functional.py:1679:14
  return (%x.55)
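In a dump like this, the points where the model switches between float and quantized tensors are the aten::quantize_per_tensor and aten::dequantize nodes; here they bracket the LocalResponseNorm section (aten::mul, aten::avg_pool2d/avg_pool3d, aten::pow, aten::div), which runs in float. A small sketch that counts node kinds in a scripted module's graph, recursing into the prim::If blocks, may make such dumps easier to scan. op_kinds is a hypothetical helper, and the float Sequential model is only a stand-in; for the real model, the scripted module would be passed instead:

```python
from collections import Counter

import torch
import torch.nn as nn


def op_kinds(block):
    """Collect node kinds, recursing into sub-blocks (e.g. prim::If branches)."""
    kinds = []
    for node in block.nodes():
        kinds.append(node.kind())
        for sub_block in node.blocks():
            kinds.extend(op_kinds(sub_block))
    return kinds


# Small float stand-in; inlined_graph inlines the submodule forward() calls.
scripted = torch.jit.script(nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU()).eval())
counts = Counter(op_kinds(scripted.inlined_graph))
for kind, n in sorted(counts.items()):
    print(f"{n:3d}  {kind}")
```

On a quantized model, the counts of aten::quantize_per_tensor and aten::dequantize give a quick view of how often the model crosses between float and quantized tensors.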

I noticed that quantizing the MobileNet-V2 model worked fine for me, just as in the PyTorch tutorial https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html?highlight=static%20quantization.
So I started debugging by comparing my model with MobileNet-V2, and found two changes that helped my model achieve better performance too:

  1. Use the _make_divisible(v, divisor, min_value=None) function from the MobileNet-V2 code, or manually set the number of channels to be divisible by 8.
  2. Use padding = (kernel_size - 1) // 2 in the convolution layers instead of padding=0.

I changed the model to the following:

import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub, fuse_modules

class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        #padding = 0
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=True),
            #nn.BatchNorm2d(out_planes, momentum=0.1),
            # Replace with ReLU
            nn.ReLU(inplace=False)
        )
    def fuse(self):
        fuse_modules(self, ['0', '1'], inplace=True)
    
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        block = [
            ConvBNReLU(1,8,3,2),
            ConvBNReLU(8,16,3,1,8),
            nn.Conv2d(16,32,1,1),
            ConvBNReLU(32,8,1,1),
        ]
        self.features = nn.Sequential(*block)
        self.classifier = nn.Linear(8,2)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        x = self.dequant(x)
        return x

With these two changes, quantization now gives me a ~38 MB reduction in memory usage on Android.
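The fuse() method in ConvBNReLU is meant for eager-mode quantization; in the FX flow used earlier, Conv2d + ReLU pairs are fused automatically (the graph above contains quantized::conv2d_relu). A minimal eager-mode sketch for the updated Net, assuming random calibration data, a 32x32 input, and the qnnpack engine when the local build supports it (fbgemm otherwise). It is a sketch of the workflow, not a substitute for calibrating on real data:

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub, fuse_modules


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=True),
            nn.ReLU(inplace=False),
        )

    def fuse(self):
        fuse_modules(self, ["0", "1"], inplace=True)  # Conv2d + ReLU -> ConvReLU2d


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            ConvBNReLU(1, 8, 3, 2),
            ConvBNReLU(8, 16, 3, 1, 8),
            nn.Conv2d(16, 32, 1, 1),
            ConvBNReLU(32, 8, 1, 1),
        )
        self.classifier = nn.Linear(8, 2)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return self.dequant(x)


# Prefer qnnpack (the engine used on Android); fall back to fbgemm if unavailable.
engines = torch.backends.quantized.supported_engines
backend = "qnnpack" if "qnnpack" in engines else "fbgemm"
torch.backends.quantized.engine = backend

model = Net().eval()
for m in [m for m in model.modules() if isinstance(m, ConvBNReLU)]:
    m.fuse()

model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(model, inplace=True)  # insert observers
with torch.no_grad():
    model(torch.randn(16, 1, 32, 32))  # calibration on random placeholder data
torch.quantization.convert(model, inplace=True)  # swap in quantized modules

scripted = torch.jit.script(model)
out = scripted(torch.zeros(1, 1, 32, 32))
print(backend, out.shape)
```

Fusing before prepare() means each Conv2d + ReLU pair gets a single observer and a single quantized op, which is what the FX flow does implicitly.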