How to change default return type of `torch.argmax`

I need to convert a PyTorch model to TFLite and accelerate it on an Android device using NNAPI. The conversion itself works, but the model contains an argmax node that is causing trouble. I'm fairly sure the reason is that `torch.argmax` always returns a LongTensor, i.e. int64, and there is no option to change that. The Android NN docs show that only integer types up to 32 bits are supported.
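
For reference, a minimal snippet showing the behaviour described above. The tensor name `logits` and its shape are placeholders, not taken from the actual model:

```python
import torch

# Placeholder input standing in for the model's pre-argmax activations.
logits = torch.randn(2, 5)

# torch.argmax has no dtype argument; the indices come back as int64.
idx = torch.argmax(logits, dim=1)
print(idx.dtype)  # torch.int64

# Casting afterwards yields int32, but the int64 tensor
# still exists in the graph before the cast.
idx32 = idx.to(torch.int32)
print(idx32.dtype)  # torch.int32
```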

Normally I would just cast the output to whatever data type I need. But because NNAPI tries to load the whole graph, the int64 output causes the model to be split into partitions, and everything after the argmax is sent to the CPU. This causes a big slowdown. Is there any way to avoid generating an int64 in the forward pass at all?
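
To make the goal concrete, here is a rough sketch of the kind of int32-only replacement that might work. It is an assumption, not a known fix: the helper name `argmax_int32` is hypothetical, and whether the TFLite converter and NNAPI accept the resulting ops (`amax`, `eq`, `where`, `amin`) has not been verified. It also does not handle NaN inputs the way `torch.argmax` does.

```python
import torch

def argmax_int32(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical argmax that never creates an int64 tensor in PyTorch."""
    n = x.shape[dim]
    # int32 positions 0..n-1, reshaped so they broadcast along `dim`.
    shape = [1] * x.dim()
    shape[dim] = n
    positions = torch.arange(n, dtype=torch.int32, device=x.device).view(shape)
    # Boolean mask marking where the maximum occurs along `dim`.
    is_max = x == torch.amax(x, dim=dim, keepdim=True)
    # Keep the position where the mask is set, otherwise a sentinel (n).
    sentinel = torch.full_like(positions, n)
    candidates = torch.where(is_max, positions, sentinel)
    # amin returns values only (no int64 indices); the smallest
    # surviving position is the first occurrence of the maximum.
    return torch.amin(candidates, dim=dim)

torch.manual_seed(0)
sample = torch.randn(4, 10)
result = argmax_int32(sample, dim=1)
reference = torch.argmax(sample, dim=1)
```

On this random input, `result` holds the same indices as `reference`, but with dtype int32. What the exported graph looks like after conversion would still need to be checked.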