I’m trying to export a model to ONNX format. One component of the model is that, for an m × n boolean matrix `M`, I need to determine the index of the first `True` in every row. The easy way to do this is just `torch.argmax(M, dim=-1)`. However, `M` is sometimes empty, and `torch.argmax` throws an error. It seems like it would be easy to fix this by just checking `M.numel() == 0`, but I get the following warning:
> TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
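For reference, the branching check described above presumably looks something like this minimal sketch. The function name `first_true_index_branching` and the fallback for empty input (one zero per row) are hypothetical choices made for illustration. The cast to `int64` before `argmax` is also an assumption, added because `argmax` support for `bool` tensors may depend on the PyTorch version. The `if` runs in Python, so `torch.jit.trace` only records whichever branch the example input takes.

```python
import torch

def first_true_index_branching(M: torch.Tensor) -> torch.Tensor:
    # Hypothetical reconstruction of the check described in the question.
    # `M.numel() == 0` is evaluated in Python, so under tracing only the
    # branch taken for the example input ends up in the recorded graph.
    if M.numel() == 0:
        # Assumed fallback for empty input: one zero per row (shape M.shape[:-1]).
        return torch.zeros(M.shape[:-1], dtype=torch.int64)
    # argmax returns the index of the first maximal value in each row.
    return M.to(torch.int64).argmax(dim=-1)
```

In eager mode this behaves as intended for both empty and non-empty `M`. Note that a row with no `True` also yields 0 here, the same as a row whose first element is `True`.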
I think this means that the exported ONNX model will always set the result of `M.numel() == 0` to `False`, which is definitely not what I want. Is there some way to avoid this issue? I’ve been trying to think of ways to find the first nonzero element without using any max functions, but it does not seem obvious.
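One possible direction, sketched here under assumptions rather than as a confirmed fix: the index of the first `True` equals the number of leading `False` entries in the row, which can be counted with `cumsum` and `sum` and no max reduction at all. A second variant keeps `argmax` but pads each row with a column of 1s so the reduced dimension is never empty. Both function names are hypothetical, and returning n (the row length) for rows with no `True` is a convention chosen for this sketch. Neither function branches on tensor values in Python, so tracing should record the same ops for empty and non-empty inputs; I have not checked the exported ONNX graph, and `CumSum` needs ONNX opset 11 or later.

```python
import torch
import torch.nn.functional as F

def first_true_index(M: torch.Tensor) -> torch.Tensor:
    # No max/argmax: the index of the first True equals the number of
    # leading False entries in each row.
    counts = M.to(torch.int64).cumsum(dim=-1)     # running count of Trues per row
    before_first = (counts == 0).to(torch.int64)  # 1 until the first True is seen
    return before_first.sum(dim=-1)               # rows with no True give n

def first_true_index_padded(M: torch.Tensor) -> torch.Tensor:
    # Keeps argmax, but appends a column of 1s to the last dimension so the
    # reduction never runs over an empty dimension; rows with no True give n.
    padded = F.pad(M.to(torch.int64), (0, 1), value=1)
    return padded.argmax(dim=-1)

M = torch.tensor([[False, True, True],
                  [False, False, False],
                  [True, False, False]])
print(first_true_index(M))         # tensor([1, 3, 0])
print(first_true_index_padded(M))  # tensor([1, 3, 0])
```

For an empty `torch.zeros(2, 0, dtype=torch.bool)`, both return `tensor([0, 0])`, which is consistent with the "no `True` gives n" convention since n = 0 there. The `cumsum` version is the one that avoids max functions entirely.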