IndexError using nonzero as an index

I am trying to reduce a tensor to only its non-zero elements. If I am reading the docs correctly, I should be able to do the following:

# y_true has torch.Size([1, 6, 480, 640])
idx = torch.nonzero(torch.any(torch.greater(y_true, 0.0), dim=1), as_tuple=True)
y_true = y_true[idx] 

However, when I do, I get this error:

IndexError: index 324 is out of bounds for dimension 1 with size 6

What have I missed?

Hi Alex!

The problem is that torch.any(t, dim=1) “reduces” away a
dimension; that is, it returns a result with one fewer dimension than
t has.
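A quick way to see the reduction is to compare shapes with and without keepdim=True. A small random tensor stands in for y_true here; the sizes are arbitrary.

```python
import torch

# Small stand-in for y_true: [batch, channels, height, width]
t = torch.randn(1, 6, 4, 5)

reduced = torch.any(t > 0.0, dim=1)             # channel dim is removed
kept = torch.any(t > 0.0, dim=1, keepdim=True)  # channel dim kept with size 1

print(t.shape)        # torch.Size([1, 6, 4, 5])
print(reduced.shape)  # torch.Size([1, 4, 5])
print(kept.shape)     # torch.Size([1, 1, 4, 5])
```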

In your case, torch.any() returns a tensor with shape [1, 480, 640].
torch.nonzero() then produces idx, a tuple of length 3, where idx[1]
indexes the size-480 (height) dimension, and therefore can contain
values up to 479. But y_true[idx] applies the three tensors in idx to the
first three dimensions of y_true in order, so idx[1] ends up indexing the
size-6 dimension of y_true, hence the IndexError.
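The mismatch can be reproduced at a small size. This sketch uses made-up data (6 channels, an 8-by-10 image, and a single positive value at row 7) so that a row index is guaranteed to exceed the channel count:

```python
import torch

# Hypothetical small example: one positive value at channel 2, row 7, column 3
y = torch.zeros(1, 6, 8, 10)
y[0, 2, 7, 3] = 1.0

mask = torch.any(torch.greater(y, 0.0), dim=1)  # shape [1, 8, 10]
idx = torch.nonzero(mask, as_tuple=True)        # 3 tensors: (batch, row, col)
print(len(idx), [i.tolist() for i in idx])      # 3 [[0], [7], [3]]

try:
    y[idx]  # idx[1] (row 7) is applied to the size-6 channel dimension
except IndexError as e:
    print("IndexError:", e)
```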

I’m not sure what your use case is, but note that idx[0] contains
only zeros (because the 0th dimension of y_true has size 1). So your
code would “work,” for example, with

y_true = y_true[(idx[0], idx[0], idx[1], idx[2])]

(but this probably isn’t what you want).

(Note that there is nothing subtle going on in using idx[0] twice; it’s
just a hack where we use idx[0] as a handy all-zero tensor that can
validly index into the size-6 dimension of y_true.)
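To see why that hack is probably not what you want: all four index tensors have the same length N, so the result is one-dimensional and keeps only channel 0 of each selected pixel. A small check with made-up random data:

```python
import torch

torch.manual_seed(0)
y = torch.randn(1, 6, 4, 5)
idx = torch.nonzero(torch.any(torch.greater(y, 0.0), dim=1), as_tuple=True)

# idx[0] is all zeros, so reusing it selects channel 0 for every pixel
hack = y[(idx[0], idx[0], idx[1], idx[2])]
print(hack.shape)  # one-dimensional: one value per selected pixel

# Same values as taking channel 0 and masking the pixels directly
mask = torch.any(y > 0.0, dim=1)
assert torch.equal(hack, y[:, 0][mask])
```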

If you could illustrate your desired use case with a complete, runnable
script (including hard-coded or random data and dimensions much
smaller than 480 and 640) that uses loops, if necessary, to produce
the result you want, we can probably help you figure out how to do
it with no-loop tensor operations.


K. Frank

You are correct: replicating idx[0] is not desirable in this case. The first dimension having size 1 is just happenstance (it is the batch dimension and will ultimately be larger than 1).

I am basically trying to replicate this TensorFlow code:

idx = tf.squeeze(tf.where(tf.reduce_any(tf.greater(y_true, 0.0), axis=-1)), axis=-1)
y_true = tf.gather(y_true, idx)
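The TensorFlow snippet reduces over the last axis (axis=-1), which suggests a channels-last layout, whereas y_true here is [batch, channels, height, width]. Assuming the goal is to keep every pixel whose channel vector has at least one positive entry (one row of C values per kept pixel), a minimal PyTorch sketch moves the channel dimension last and indexes with the boolean mask directly, so the nonzero tuple is not needed. The helper name select_nonzero_pixels is hypothetical, the data is made up, and the loop is only a cross-check.

```python
import torch

def select_nonzero_pixels(y_true: torch.Tensor) -> torch.Tensor:
    """Keep pixels whose channel vector has any value > 0.

    Assumes y_true is [B, C, H, W]; returns [N, C], one row per kept pixel.
    (Hypothetical helper name, for illustration only.)
    """
    mask = torch.any(torch.greater(y_true, 0.0), dim=1)  # [B, H, W]
    channels_last = y_true.permute(0, 2, 3, 1)           # [B, H, W, C]
    return channels_last[mask]                           # [N, C]

# Made-up data with batch size 2, so the batch dimension is not special
torch.manual_seed(0)
y_true = torch.relu(torch.randn(2, 6, 4, 5) - 1.0)
out = select_nonzero_pixels(y_true)
print(out.shape)  # [N, 6], N = number of pixels with any positive channel

# Loop-based cross-check of the same selection, in (batch, row, col) order
rows = [y_true[b, :, h, w]
        for b in range(y_true.shape[0])
        for h in range(y_true.shape[2])
        for w in range(y_true.shape[3])
        if (y_true[b, :, h, w] > 0).any()]
expected = torch.stack(rows) if rows else y_true.new_empty(0, y_true.shape[1])
assert torch.equal(out, expected)
```

Equivalently, the original idx tuple can be used with a slice for the channel dimension, y_true[idx[0], :, idx[1], idx[2]], which also yields [N, C]: as in NumPy, when advanced indices are separated by a slice, the indexed dimension is moved to the front of the result.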