Dealing with unsupported operations with NNAPI

Hi,

I am trying to run a model on Android with NNAPI; however, I have to deal with some simple constructions that are unsupported:

  • Addition between constant and tensor (e.g. X + 1)
  • Division of tensor by a constant (e.g. X / 4)

For the first one, is there an easy way to convert it into something NNAPI supports? (I tried creating a torch.ones() to replace the constant, but creating a new tensor while running on NNAPI doesn’t appear to be supported.)

For the second one, I understand that dividing tensors is not currently supported for NNAPI. Is there a clean way to call a non-NNAPI module/function from a module running through NNAPI in order to still perform these operations?
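
For reference, here is roughly the kind of construction I mean (a minimal sketch, not my actual model):

import torch

class MinimalExample(torch.nn.Module):
    def forward(self, x):
        y = x + 1   # addition between a tensor and a constant
        z = y / 4   # division of a tensor by a constant
        return z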

Thanks,

Julien

Any suggestions on the matter?

Hi there, apologies for the delay.

For addition, that seems to be something we need to fix. It looks like X + constant (tensor or scalar) fails, while X + Y works when Y is a tensor such as torch.ones(1).

The division op is in the works! Dividing by a constant might have the same issue, but I’ll take a look.
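
In the meantime, one thing you could try (untested on my end, and the mul path might hit the same constant-handling issue) is to rewrite the division as a multiplication by the reciprocal, with the constant expressed as a tensor rather than a scalar:

import torch

class DivWorkaround(torch.nn.Module):
    def forward(self, x):
        # x / 4 rewritten as x * 0.25, using a tensor constant instead of a scalar
        return x * torch.full((1,), 0.25)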

Hi,

Thanks for the reply. The torch.ones(1) workaround does actually seem to work, but only for non-quantized models:

class TestNNAPI(torch.nn.Module):
    def __init__(self):
        super(TestNNAPI, self).__init__()
        self.conv = torch.nn.Conv2d(1, 6, kernel_size=3, padding=1, stride=1)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        res = x + torch.ones(1)
        return res

Tracing and converting this to NNAPI seems to work smoothly; however, if the model is quantized (at least with the ‘qnnpack’ backend), I get the following error:

Traceback (most recent call last):
  File "./test-constant.py", line 36, in <module>
    traced_nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced_nnapi_model, testinput)
  File "<install>/lib/python3.6/site-packages/torch/backends/_nnapi/prepare.py", line 85, in convert_model_to_nnapi
    ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines, retval_count = serialize_model(model, inputs)
  File "<install>/lib/python3.6/site-packages/torch/backends/_nnapi/serializer.py", line 1739, in serialize_model
    return _NnapiSerializer(config).serialize_model(module, inputs)
  File "<install>/lib/python3.6/site-packages/torch/backends/_nnapi/serializer.py", line 628, in serialize_model
    self.add_node(node)
  File "<install>/lib/python3.6/site-packages/torch/backends/_nnapi/serializer.py", line 793, in add_node
    adder(self, node)
  File "<install>/lib/python3.6/site-packages/torch/backends/_nnapi/serializer.py", line 750, in <lambda>
    self.add_add_sub_op(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE),
  File "<install>/lib/python3.6/site-packages/torch/backends/_nnapi/serializer.py", line 1119, in add_add_sub_op
    self._do_add_binary(node, opcode, fuse_code)
  File "<install>/.local/lib/python3.6/site-packages/torch/backends/_nnapi/serializer.py", line 1074, in _do_add_binary
    in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1))
  File "<install>/lib/python3.6/site-packages/torch/backends/_nnapi/serializer.py", line 456, in get_tensor_operand_by_jitval
    operand_id = self.jitval_operand_map[jitval]
KeyError: 31 defined in (%31 : Float(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]()
)

I find this a bit surprising, since the torch.ones is not even used in a quantized context. (And if I try to create a quantized torch.ones, I get the error Could not run 'aten::empty.memory_format' with arguments from the 'QuantizedCPU' backend.)

Is this a known limitation?

Can you post the full script you’re using to reproduce this?

Yes, here’s the reproducer:

#!/usr/bin/python3

import torch
import torch.backends._nnapi.prepare

class TestNNAPI(torch.nn.Module):
    def __init__(self):
        super(TestNNAPI, self).__init__()
        self.conv = torch.nn.Conv2d(1, 6, kernel_size=3, padding=1, stride=1)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.dequant(x)
        res = x + torch.ones(1)
        return res

nnapi_model = TestNNAPI()
nnapi_model.eval()
nnapi_model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.backends.quantized.engine = 'qnnpack'

testinput = torch.rand(6, 1, 600, 800)
testinput.nnapi_nhwc = True  # mark the example input as NHWC for the NNAPI converter

# Post-training static quantization: prepare, calibrate with one pass, convert.
torch.quantization.prepare(nnapi_model, inplace=True)
nnapi_model(testinput)
torch.quantization.convert(nnapi_model, inplace=True)

# Trace the quantized model and convert it for NNAPI; this last step raises the KeyError.
traced_nnapi_model = torch.jit.trace(nnapi_model, testinput)
traced_nnapi_model = torch.backends._nnapi.prepare.convert_model_to_nnapi(traced_nnapi_model, testinput)

If I remove the quantization prepare/convert steps, the script runs fine.

Hi David,

Do you have any update regarding the reproducer? Is there a workaround for this issue?

Thanks,

Julien

Sorry for the late reply. I’ll try to add support for this.

I was able to get this working (without broadcasting) with this change: change.patch · GitHub

It’s a bit half-baked, though. I’ll do some cleanup and should get this landed in time for the next release.