How to change default return type of `torch.argmax`

I need to convert a torch model to TFLite and accelerate it on an Android device using the NNAPI. I can perform the conversion itself, but my model has an argmax node in it that is causing trouble. The reason it's causing trouble, I'm quite sure, is that the only return type of `torch.argmax` is a LongTensor, i.e. int64, and the Android NNAPI documentation says only integer types up to 32 bits are supported.
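For context, `torch.argmax` has no dtype argument, so as far as I can tell the int64 is unavoidable at the op level (tiny repro, shapes made up):

```python
import torch

logits = torch.randn(1, 21, 224, 224)   # e.g. per-class segmentation scores
idx = torch.argmax(logits, dim=1)
print(idx.dtype)                         # torch.int64 -- no way to request int32 here
```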

Typically I would just cast the output to whatever data type I needed. But because the NNAPI tries to load the whole graph, the unsupported int64 ends up splitting the model into partitions, and everything after the argmax falls back to the CPU. This causes a big slowdown. Is there any way I can avoid generating an int64 in the forward pass at all?
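To illustrate why the cast doesn't help: even with an explicit `.to(torch.int32)`, the traced graph still contains an int64-producing `aten::argmax` node, and the delegate partitions around it. A minimal sketch (module name and shapes are made up):

```python
import torch
import torch.nn as nn

class Head(nn.Module):
    def forward(self, logits):
        # The cast only changes the dtype of the downstream tensor;
        # the argmax node itself still emits an int64 inside the graph.
        return torch.argmax(logits, dim=1).to(torch.int32)

traced = torch.jit.trace(Head(), torch.randn(1, 21, 224, 224))
print(traced.graph)   # shows aten::argmax (Long) followed by aten::to (Int)
```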


Hello,

Did you by any chance find a solution to this problem?

Unfortunately not really. The workaround was to first convert to ONNX, where I could edit the graph so the argmax explicitly produces an int32, and then convert from ONNX to TFLite. A frustrating process.
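In case it helps anyone, the ONNX surgery was along these lines: since the ONNX ArgMax op itself only emits int64, you can insert a Cast to int32 immediately after it and rewire the consumers, so nothing downstream sees an int64. This is a rough sketch, not the exact script I used (file names are placeholders, and if the argmax feeds a graph output you would also need to update that output's declared dtype):

```python
import onnx
from onnx import TensorProto, helper

model = onnx.load("model.onnx")          # path is just an example
graph = model.graph

# Find every ArgMax node first so inserting Cast nodes doesn't
# disturb the indices we iterate over.
argmax_positions = [i for i, n in enumerate(graph.node) if n.op_type == "ArgMax"]

for i in reversed(argmax_positions):
    node = graph.node[i]
    int64_out = node.output[0]
    int32_out = int64_out + "_int32"

    cast = helper.make_node(
        "Cast",
        inputs=[int64_out],
        outputs=[int32_out],
        to=TensorProto.INT32,
        name=int64_out + "_cast",
    )

    # Rewire every consumer of the original int64 output to the casted tensor.
    for other in graph.node:
        for j, inp in enumerate(other.input):
            if inp == int64_out:
                other.input[j] = int32_out

    # Place the Cast immediately after the ArgMax it belongs to.
    graph.node.insert(i + 1, cast)

onnx.checker.check_model(model)
onnx.save(model, "model_int32_argmax.onnx")
```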