How to implement torch.argmax in Java

Here is the Python code I want to implement in Java for a PyTorch Mobile application, but there seems to be no equivalent API:

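```python
pred = torch.argmax(outputs[0], 1)   # index of the highest score along dimension 1
pred = pred.cpu().data.numpy()       # move to the CPU and convert to a NumPy array
predict = pred.squeeze(0)            # drop the leading batch dimension
```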

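```java
// outputTuple[0] corresponds to outputs[0] in the Python code.
final IValue[] outputTuple = module.forward(IValue.from(inputTensor)).toTuple();
final Tensor outputTensor = outputTuple[0].toTensor();
// Arrays.toString prints the contents (not the array reference); the data is flattened row-major.
Log.d(TAG, "onCreate: shape=" + java.util.Arrays.toString(outputTensor.shape())
        + " data=" + java.util.Arrays.toString(outputTensor.getDataAsFloatArray()));
```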

The float array is one-dimensional (the tensor is flattened), and I can't figure out how to get the same result from it. Could anyone give me a hint? Thank you.
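
One way to reproduce `torch.argmax(outputs[0], 1)` on the Java side is to index the flat array using the tensor's shape. The sketch below assumes the array is the contiguous, row-major data returned by `getDataAsFloatArray()` and that the shape comes from `outputTensor.shape()`; `FlatArgmax` and `argmaxDim1` are hypothetical helpers, not part of the PyTorch Android API.

```java
// Sketch: torch.argmax(x, 1) computed on a flat, row-major float array plus its shape.
final class FlatArgmax {
    // Returns indices with shape [shape[0], shape[2], ..., shape[n-1]], flattened row-major.
    // Ties resolve to the lowest index; NaN values are not handled specially.
    static long[] argmaxDim1(float[] data, long[] shape) {
        if (shape.length < 2 || shape[1] < 1) {
            throw new IllegalArgumentException("need at least 2 dims and a non-empty dim 1");
        }
        int outer = (int) shape[0];   // batch size
        int classes = (int) shape[1]; // size of the dimension being reduced
        int inner = 1;                // product of the trailing dimensions (e.g. H * W)
        for (int d = 2; d < shape.length; d++) {
            inner *= (int) shape[d];
        }
        if ((long) outer * classes * inner != data.length) {
            throw new IllegalArgumentException("shape does not match data length");
        }
        long[] result = new long[outer * inner];
        for (int o = 0; o < outer; o++) {
            int base = o * classes * inner; // start of this batch entry
            for (int i = 0; i < inner; i++) {
                int best = 0;
                float bestVal = data[base + i]; // score of index 0 at position i
                for (int c = 1; c < classes; c++) {
                    float v = data[base + c * inner + i];
                    if (v > bestVal) {
                        bestVal = v;
                        best = c;
                    }
                }
                result[o * inner + i] = best;
            }
        }
        return result;
    }
}
```

In the app this would be called as `FlatArgmax.argmaxDim1(outputTensor.getDataAsFloatArray(), outputTensor.shape())`. For a `[1, C, H, W]` output the result holds `H * W` indices in row-major order, which is the same data `squeeze(0)` leaves in Python. Doing the loop in Java avoids re-exporting the model, at the cost of copying every score into a Java array.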

Moving this post-processing into the model's forward method solved the problem:

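```python
def forward(self, x):
    ...
    pred = torch.argmax(outputs[0], 1)          # index of the highest score along dim 1 (int64)
    pred = pred.cpu().data
    predict = pred.squeeze(0)                   # drop the batch dimension
    predict = predict.type(torch.FloatTensor)   # float, so Java can read it with getDataAsFloatArray()

    return predict
```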
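
With this forward method, the returned tensor no longer has the shape of `outputs[0]`: argmax over dimension 1 removes that dimension, and `squeeze(0)` drops dimension 0 only when its size is 1. The shape bookkeeping can be sketched in Java as follows; `ArgmaxShape` and its methods are hypothetical names used for illustration.

```java
import java.util.Arrays;

// Sketch: how torch.argmax(x, 1) followed by squeeze(0) changes a tensor's shape.
final class ArgmaxShape {
    // argmax over dim 1 (keepdim=False, the default) removes that dimension.
    static long[] afterArgmaxDim1(long[] shape) {
        if (shape.length < 2) {
            throw new IllegalArgumentException("need at least 2 dims");
        }
        long[] out = new long[shape.length - 1];
        out[0] = shape[0];
        System.arraycopy(shape, 2, out, 1, shape.length - 2);
        return out;
    }

    // squeeze(0) drops dim 0 only if its size is 1; otherwise the shape is unchanged.
    static long[] afterSqueeze0(long[] shape) {
        if (shape.length > 0 && shape[0] == 1) {
            return Arrays.copyOfRange(shape, 1, shape.length);
        }
        return shape;
    }
}
```

For example, a `[1, 21, 4, 3]` output becomes `[1, 4, 3]` after argmax and `[4, 3]` after `squeeze(0)`, so the traced model would return 12 values instead of 252.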

Run the traced_script again to export a new model.pt and put it in the app's assets folder. The output tensor then contains the argmax indices instead of the raw scores and can be used directly.
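
Assuming `outputs[0]` was 4-D with batch size 1 (for example per-pixel scores), the new output is a 2-D `[H, W]` tensor of indices stored as floats. A minimal sketch for turning the flat array into a grid, where `IndexGrid` and `toGrid` are hypothetical names and the inputs are assumed to come from `getDataAsFloatArray()` and `shape()`:

```java
// Sketch: convert the flat float array of a 2-D [H, W] index tensor into an int grid.
final class IndexGrid {
    static int[][] toGrid(float[] data, long[] shape) {
        if (shape.length != 2 || shape[0] * shape[1] != data.length) {
            throw new IllegalArgumentException("expected a 2-D shape matching the data length");
        }
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        int[][] grid = new int[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                // Indices below 2^24 are exact in float; rounding is only a cheap safeguard.
                grid[r][c] = Math.round(data[r * cols + c]);
            }
        }
        return grid;
    }
}
```

If the `.type(torch.FloatTensor)` line were dropped, the traced model should return int64 indices, which `org.pytorch.Tensor` exposes through `getDataAsLongArray()`; the float conversion in the answer keeps the Java side on `getDataAsFloatArray()`.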