Empty tensors, argmax and ONNX exporting

I’m trying to export a model to ONNX format. One component of the model needs, for an n-by-n boolean matrix M, the index of the first True in every row. The easy way to do this is torch.argmax(M, dim=-1). However, M is sometimes empty, and then torch.argmax throws an error. It seemed like it would be easy to fix this by checking M.numel() == 0, but that produces the following warning:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
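A minimal eager-mode sketch of the argmax approach with the numel() guard described above. The function name first_true_index_eager is hypothetical. The input is cast to int64 so the example does not depend on argmax supporting bool tensors. argmax returns the index of the first maximal value, i.e. the first True (or 0 for a row with no True at all). The Python `if` is the data-dependent branch that the tracer cannot record.

```python
import torch

def first_true_index_eager(M: torch.Tensor) -> torch.Tensor:
    """Index of the first True in each row of a 2-D bool tensor M (eager mode only)."""
    if M.numel() == 0:
        # Nothing to reduce over: return an empty int64 result, one entry per row.
        return torch.zeros(M.shape[0], dtype=torch.int64)
    # argmax picks the first maximal index, i.e. the first 1 (True) in each row.
    # A row with no True at all also yields 0, which is ambiguous.
    return torch.argmax(M.to(torch.int64), dim=-1)

M = torch.tensor([[False, True, True],
                  [True, False, False],
                  [False, False, True]])
print(first_true_index_eager(M))  # tensor([1, 0, 2])
```

Without the guard, calling argmax with dim=-1 on a (0, 0) tensor raises, because the reduced dimension has size 0.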

I think the warning means that the exported ONNX model will hard-code the result of M.numel() == 0 as either True or False, whichever it was for the example input used during tracing. That is definitely not what I want. Is there some way to avoid this issue? I’ve been trying to think of ways to find the first nonzero element without using any min or max functions, but nothing obvious comes to mind.
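One way to find the first True without min, max, argmax or any Python branch is to use cumulative sums. The running count of Trues in a row first equals 1 exactly at the first True, so masking with M leaves a single 1 per row. Multiplying by the column positions and summing then recovers the index. This is a sketch, not a verified recipe. The function name is hypothetical, and sending "no True" rows to n is an arbitrary choice. Column positions are built from ones_like(...).cumsum() instead of M.shape, to keep everything as tensor ops. ONNX has a CumSum operator (opset 11+), but I have not checked that this exact function exports cleanly.

```python
import torch

def first_true_index(M: torch.Tensor) -> torch.Tensor:
    """Branch-free index of the first True in each row of a 2-D bool tensor.

    Rows with no True map to the row length n. A (0, 0) input yields an
    empty (0,) result with no special case.
    """
    Mi = M.to(torch.int64)
    ones = torch.ones_like(Mi)
    # Column positions 0, 1, ..., n-1, derived from the data rather than M.shape.
    positions = ones.cumsum(dim=-1) - 1
    # The running count of Trues equals 1 only from the first True until the
    # next one; multiplying by Mi keeps just the first True itself.
    first = Mi * (Mi.cumsum(dim=-1) == 1).to(torch.int64)
    # At most one 1 per row in `first`, so this sum is that column's position.
    idx = (first * positions).sum(dim=-1)
    # found is 1 if the row contains a True, else 0; n is the row length.
    found = first.sum(dim=-1)
    n = ones.sum(dim=-1)
    return idx + (1 - found) * n

M = torch.tensor([[False, True, True],
                  [False, False, False],
                  [True, True, False]])
print(first_true_index(M))  # tensor([1, 3, 0])
print(first_true_index(torch.zeros(0, 0, dtype=torch.bool)).shape)  # torch.Size([0])
```

Every step here is a tensor operation, so there is no Python bool for the tracer to freeze. The cost is a few extra elementwise ops compared with a single argmax.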


It turns out that one thing that seems to work is using torch.equal(M, torch.empty(0,0)) instead of M.numel() == 0.
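The torch.equal workaround can be sketched as follows. The helper name is hypothetical, and passing dtype=M.dtype to torch.empty is an assumption added here so the comparison does not rely on how torch.equal treats mismatched dtypes. One thing worth checking: torch.equal returns a plain Python bool rather than a tensor. It may therefore only silence the warning while the trace still records whichever branch the example input took. Running the exported model on both an empty and a non-empty M would show whether it really generalizes.

```python
import torch

def first_true_index_equal_guard(M: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper wrapping the torch.equal check from the post.
    # dtype=M.dtype is an assumption so both tensors share a dtype.
    empty = torch.empty(0, 0, dtype=M.dtype)
    # torch.equal is True only for the same size and the same elements, and it
    # returns a Python bool, so this `if` is still ordinary Python control flow.
    if torch.equal(M, empty):
        return torch.zeros(0, dtype=torch.int64)
    # Same argmax as before; int64 cast avoids relying on bool support.
    return torch.argmax(M.to(torch.int64), dim=-1)

M = torch.tensor([[False, True], [True, True]])
print(type(torch.equal(M, torch.empty(0, 0, dtype=M.dtype))))  # <class 'bool'>
print(first_true_index_equal_guard(M))  # tensor([1, 0])
```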