I need to convert a torch model to TFLite and accelerate it on an Android device using the NNAPI. I can perform the conversion itself, but my model has an `argmax` node in it that is causing trouble. The reason it's causing trouble, I'm quite sure, is that the default and only return type of `torch.argmax` is a `LongTensor`, i.e. an `int64`. The Android NN docs show that only integer types up to 32-bit are supported.
Typically I would just cast the output to whatever data type I needed. But because the NNAPI tries to load the whole graph, it ends up breaking the model into partitions and sending everything after the `argmax` to the CPU, which causes a big slowdown. Is there any way I can avoid generating an `int64` in the forward pass at all?
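To make the problem concrete, here is a minimal sketch of the situation (the `Model` class and dimensions are made up for illustration). The cast to `int32` happens *after* the `argmax`, so the `int64` tensor still appears as an intermediate in the exported graph, and the delegate partitions around it:

```python
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        # torch.argmax always returns an int64 (LongTensor);
        # there is no dtype argument to change this.
        idx = torch.argmax(x, dim=1)
        # Casting afterwards fixes the *output* dtype, but the int64
        # intermediate is still present in the traced/exported graph,
        # so NNAPI still splits the model at this point.
        return idx.to(torch.int32)

out = Model()(torch.randn(2, 5))
print(out.dtype)  # torch.int32, but the graph still contains an int64 op
```

In other words, the cast only changes what the caller sees; it doesn't remove the unsupported `int64` operand from the graph that the NNAPI delegate inspects.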