I’m trying to export a model to ONNX format. One component of the model needs, for an n-by-n boolean matrix M, the index of the first True in every row. The easy way to do this is torch.argmax(M, dim=-1). However, M is sometimes empty, and torch.argmax raises an error in that case. It seems like it would be easy to fix this by checking M.numel() == 0, but then I get the following warning:
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
I think this means that the exported ONNX model will always hard-code the result of M.numel() == 0 as either True or False (whichever it was for the example input used during tracing), which is definitely not what I want. Is there some way to avoid this issue? I’ve been trying to think of ways to find the first nonzero element in each row without using any min or max functions, but nothing obvious comes to mind.
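One branch-free idea, as a sketch: the running count of Trues along a row (a cumulative sum) is 0 exactly at the positions before the first True, so counting those zeros gives the index of the first True without min, max, or any Python-level check on M. A second variant keeps argmax but pads every row with an extra True column using torch.nn.functional.pad, so the reduced dimension is never empty. The function names first_true_cumsum and first_true_padded_argmax are hypothetical. Both return the row length n for rows with no True, whereas plain torch.argmax returns 0 for such rows. M is cast to int64 first because some PyTorch versions may not implement these reductions for bool tensors. This has only been reasoned about for eager mode and torch.jit.trace; the ONNX export itself is unverified, and the ONNX CumSum operator requires opset 11 or later.

```python
import torch
import torch.nn.functional as F


def first_true_cumsum(M: torch.Tensor) -> torch.Tensor:
    """Index of the first True in each row, using no min/max and no Python branch.

    Hypothetical helper. The running count of Trues stays 0 until the first
    True, so the number of zeros in the cumulative sum equals that index.
    Rows with no True get the row length. An empty (0, 0) input gives shape (0,).
    """
    counts = M.to(torch.int64).cumsum(dim=-1)  # running number of Trues per row
    return (counts == 0).sum(dim=-1)           # positions before the first True


def first_true_padded_argmax(M: torch.Tensor) -> torch.Tensor:
    """Same result via argmax, made safe for empty input (hypothetical helper).

    Appending a column of 1s turns shape (n, n) into (n, n + 1), so the reduced
    dimension is never empty and all-False rows map to index n.
    """
    padded = F.pad(M.to(torch.int64), (0, 1), value=1)  # pad 1 column on the right
    # torch.argmax returns the index of the first maximal value in each row
    return padded.argmax(dim=-1)


if __name__ == "__main__":
    M = torch.tensor([[False, True, True],
                      [False, False, False],
                      [True, False, False]])
    print(first_true_cumsum(M))         # tensor([1, 3, 0])
    print(first_true_padded_argmax(M))  # tensor([1, 3, 0])

    empty = torch.zeros(0, 0, dtype=torch.bool)
    print(first_true_cumsum(empty).shape)         # torch.Size([0])
    print(first_true_padded_argmax(empty).shape)  # torch.Size([0])
```

Because neither function converts a tensor to a Python value, tracing should record the same graph regardless of whether M is empty. When exporting, the input’s dimensions would presumably also need to be marked dynamic (for example via the dynamic_axes argument of torch.onnx.export); otherwise the exported graph may be specialized to the example input’s shape.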