I need to convert a PyTorch model to TFLite and accelerate it on an Android device using NNAPI. I can perform the conversion itself, but my model has an `argmax` node in it that is causing trouble. The reason it's causing trouble, I'm quite sure, is that the default (and only) return type of `torch.argmax` is a `LongTensor`, i.e. `int64`, while the Android NN docs show that only integer types up to 32 bits are supported.
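To illustrate what I mean, `torch.argmax` has no `dtype` argument, so the `int64` output is unavoidable at the op itself:

```python
import torch

x = torch.randn(1, 10)
idx = torch.argmax(x, dim=1)
print(idx.dtype)  # torch.int64 -- there is no way to request a smaller dtype here
```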
Typically I would just cast the output to whatever data type I needed. But because NNAPI tries to consume the whole graph, the unsupported `int64` causes it to break the model into partitions and send everything after the `argmax` to the CPU, which causes a big slowdown. Is there any way I can avoid generating an `int64` in the forward pass at all?
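One direction I've been considering (a sketch only; the function name is mine and I haven't verified how the NNAPI delegate partitions it) is to rebuild the argmax from ops whose outputs stay in `int32`. I use `torch.amax` rather than `torch.max(dim=...)` because the latter returns an `int64` indices tensor alongside the values. Note this variant returns the *last* index on ties, whereas `torch.argmax` returns the first, and the broadcast as written assumes reducing over the final dimension of a 2-D input:

```python
import torch

def argmax_int32(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Max values only -- amax returns no index tensor, so no int64 appears.
    maxval = torch.amax(x, dim=dim, keepdim=True)
    # 1 at every position holding the max, 0 elsewhere, already in int32.
    mask = (x == maxval).to(torch.int32)
    # int32 index ramp along the reduced dimension.
    idx = torch.arange(x.size(dim), dtype=torch.int32)
    # Largest surviving index (last match on ties), still int32.
    return torch.amax(mask * idx, dim=dim)

x = torch.tensor([[0.1, 0.9, 0.3],
                  [0.7, 0.2, 0.5]])
out = argmax_int32(x, dim=1)
print(out)  # tensor([1, 0], dtype=torch.int32)
```

Whether every op here lowers cleanly onto the accelerator is exactly what I'm unsure about, so I'd welcome corrections.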