First nonzero index

I have a batch of N rows, each with M values sorted along dim=1. For each row, I want to find the index of the first nonzero element among the M sorted values. I'd like to do it efficiently, without a for-loop.

x = torch.randn(5, 7)
x[x<0] = 0
x = x.sort(dim=1)
first_nonzero = f(x)
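For comparison, a plain per-row loop (the approach the question wants to avoid) can serve as a reference when checking the vectorized answers. This is a minimal sketch; `first_nonzero_loop` is a hypothetical name, and returning -1 for all-zero rows is an assumed convention, not something the question specifies.

```python
import torch

def first_nonzero_loop(x):
    # Hypothetical reference implementation: scan each row separately.
    # Assumed convention: -1 for rows that contain no nonzero value.
    out = torch.full((x.size(0),), -1, dtype=torch.long)
    for i in range(x.size(0)):
        nz = torch.nonzero(x[i]).flatten()  # positions of nonzero entries, ascending
        if nz.numel() > 0:
            out[i] = nz[0]
    return out

x = torch.tensor([[0.0, 0.0, 0.5, 1.2],
                  [0.0, 0.0, 0.0, 0.0],
                  [0.3, 0.7, 0.9, 2.0]])
print(first_nonzero_loop(x).tolist())  # [2, -1, 0]
```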

Hi, I coded a solution for your problem. You can check it in the link below.

Hope this works out for you.


Here is another try, especially if you want to only use torch APIs.

import torch

def f(x):
    # non zero values mask
    non_zero_mask = x != 0

    # operations on the mask to find first nonzero values in the rows
    mask_max_values, mask_max_indices = torch.max(non_zero_mask, dim=1)

    # if the max-mask is zero, there is no nonzero value in the row
    mask_max_indices[mask_max_values == 0] = -1
    return mask_max_indices

x = torch.randn(4, 5)
x[x<0] = 0
x, sort_indices = x.sort(dim=1)
print('x', x)
first_nonzero = f(x)

print(first_nonzero)

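It is easier to count the number of zero elements along that dimension. Since each row is sorted in ascending order and the negatives have been set to 0, all zeros come first, so their count is exactly the index of the first nonzero element: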

import torch

x = torch.randn(5, 7)
x[x<0] = 0
x = x.sort(dim=1)[0] # You forgot that sort returns a pair
first_nonzero = (x == 0).sum(dim=1)

Even easier, you can skip the x[x<0] = 0 line and count the non-positive elements:

x = torch.randn(5, 7)
x = x.sort(dim=1)[0]
first_positive = (x <= 0).sum(dim=1)
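A small worked example, using hand-picked tensors (illustrative data, not from the thread) whose rows are already sorted, shows why counting works: in an ascending row every zero (or non-positive value) precedes every positive one. Note that the all-zero row yields x.size(1), which is the caveat raised in the next reply.

```python
import torch

# Hand-picked rows, already clamped at 0 and sorted ascending along dim=1.
x = torch.tensor([[0.0, 0.0, 0.4, 1.1, 2.5],
                  [0.0, 0.3, 0.8, 0.9, 1.7],
                  [0.0, 0.0, 0.0, 0.0, 0.0]])

# The number of zeros in each row is the index of its first nonzero entry.
first_nonzero = (x == 0).sum(dim=1)
print(first_nonzero.tolist())  # [2, 1, 5]

# Without clamping, the same reasoning applies to non-positive values.
y = torch.tensor([[-1.5, -0.2, 0.0, 0.7],
                  [-3.0, 0.1, 0.2, 0.4]])
first_positive = (y <= 0).sum(dim=1)
print(first_positive.tolist())  # [3, 1]
```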

If a row contains no positive numbers, the answer might be misleading. Just a minor caveat to handle, I guess.

True, in that case first_nonzero[i] == x.size(1). I don't think it's a caveat, though, since you'd have to mark those cases somehow anyway.
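If you prefer an explicit marker instead of x.size(1), one option is to post-process the count with torch.where. This is a sketch on hand-picked sorted rows; the -1 sentinel is an assumed convention, matching the one used in other answers in this thread.

```python
import torch

x = torch.tensor([[-1.0, 0.5, 2.0],
                  [-3.0, -2.0, 0.0],   # no positive value in this row
                  [0.1, 0.2, 0.3]])
x = x.sort(dim=1).values

count = (x <= 0).sum(dim=1)  # index of the first positive value, or x.size(1) if none
# Assumed convention: replace the x.size(1) sentinel with -1.
first_positive = torch.where(count == x.size(1), torch.full_like(count, -1), count)
print(first_positive.tolist())  # [1, -1, 0]
```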

Thanks for the replies.
I'm interested in a PyTorch-only implementation to keep my code eco-friendly. Sorry for not mentioning this beforehand.
@InnovArul, your solution assumes that torch.max always returns the first occurrence, and I can't find that stated in the docs. Though it works with the current version (0.4.1), I think it's better not to rely on internal implementation details of PyTorch functions. For that reason, @adrianjav's answer looks more robust, as long as ByteTensor supports the sum operation.
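Recent versions of the PyTorch documentation for torch.argmax state that when several maximal values exist, the index of the first one is returned, so a max-based approach can rely on documented behavior there (check the docs for the version you use; I'm not certain which release first documented it). A sketch under that assumption; `first_nonzero_argmax` is a hypothetical name and -1 for all-zero rows is an assumed convention.

```python
import torch

def first_nonzero_argmax(x):
    # Cast the mask to int, since argmax may not accept bool tensors in some versions.
    mask = (x != 0).int()
    # argmax is documented (in recent versions) to return the first maximal index.
    idx = mask.argmax(dim=1)
    # Rows without any nonzero value: mark them with -1 (assumed convention).
    has_nonzero = (x != 0).any(dim=1)
    idx[~has_nonzero] = -1
    return idx

x = torch.tensor([[0, 0, 3, 0, 5],
                  [0, 0, 0, 0, 0],
                  [1, 0, 0, 0, 0]])
print(first_nonzero_argmax(x).tolist())  # [2, -1, 0]
```

Unlike the counting approach, this does not need the rows to be sorted.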

Yes. I like @adrianjav's answer as well; it looks more fail-safe to me. All the best!


I was looking for a solution that does not assume the input is already sorted. Since the topic title does not mention this assumption, it may be useful to have a solution that also works for unsorted input. I think the following should work in the general case. The idea is that an element is the first nonzero element if it is nonzero and the cumulative sum of a nonzero indicator, up to and including that element, equals 1.

import torch

def first_nonzero(x, axis=0):
    nonz = (x > 0)
    return ((nonz.cumsum(axis) == 1) & nonz).max(axis)

x = (torch.rand(10, 5) * 10 - 6).int().clamp(0, 10)
print(x)

# The function returns whether each row has any nonzeros, and the index of the first nonzero (0 if there is none)
any_nonz, idx_first_nonz = first_nonzero(x, axis=1)
print(any_nonz, idx_first_nonz)

# If you want -1 for rows with no nonzeros
idx_first_nonz[any_nonz == 0] = -1
print(idx_first_nonz)
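To see why the cumsum idea needs the extra "and it is nonzero" condition, here it is traced on a single hand-picked row (illustrative data, not from the thread). The running count stays at 1 from the first nonzero until the next nonzero appears, so a zero in between would also match == 1 without the mask.

```python
import torch

row = torch.tensor([0, 0, 3, 0, 5, 1])
nonz = row != 0
running = nonz.cumsum(dim=0)           # running count of nonzeros seen so far
print(running.tolist())                # [0, 0, 1, 1, 2, 3]
print((running == 1).tolist())         # [False, False, True, True, False, False]

is_first = (running == 1) & nonz       # index 3 is dropped because row[3] == 0
print(is_first.tolist())               # [False, False, True, False, False, False]
print(is_first.nonzero().item())       # 2
```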

@wouter’s solution above was what I was looking for.

I would like to point out two things about it, though.

First, it actually finds the first positive element, not the first nonzero element, i.e. the first element that is strictly greater than zero.

Second, it has a failure case when no element along the axis is positive. In this case it returns 0, which suggests that there is a positive element at index zero, which is wrong.

One way to improve this is to find the places where the sum of (x > 0) along the axis equals zero, which means that no element along that axis is positive. We can then set those entries to a sentinel value such as -1 (float('nan') would not work here, since the indices are stored in an integer tensor). The already computed cumulative sum gives us this sum for free: its last entry along the axis is the total.

import torch

def first_positive(tensor, axis=0, value_if_none=-1):
    positive = tensor > 0
    cumsum = positive.cumsum(axis)
    # The last cumulative sum along `axis` is the number of positive elements.
    none_positive = cumsum.select(axis, -1) == 0
    first_idx = ((cumsum == 1) & positive).int().max(axis).indices
    first_idx[none_positive] = value_if_none
    return first_idx

Glad it was of help to you! You are right that it finds the first positive value rather than the first nonzero one (I guess I implicitly assumed only non-negative values), but you can easily replace nonz = (x > 0) with nonz = (x != 0).
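A sketch of the difference between the two masks on a tensor with negative values, using a hypothetical helper `first_true` that applies the same cumsum trick to any boolean mask and (by assumption) returns -1 when nothing matches. After masking, each row has at most one 1, so max picks it unambiguously and tie-breaking only matters for rows without a match, which are overwritten anyway.

```python
import torch

def first_true(mask, axis=0):
    # Hypothetical helper: cumsum trick from the thread, applied to any boolean mask.
    hit = ((mask.cumsum(axis) == 1) & mask).int()  # at most one 1 per slice
    found, idx = hit.max(axis)
    idx[found == 0] = -1  # assumed convention: -1 when nothing matches
    return idx

x = torch.tensor([[-2, 0, 3, 1],
                  [0, -1, 0, 4],
                  [0, 0, 0, 0]])

print(first_true(x > 0, axis=1).tolist())   # [2, 3, -1]  first positive
print(first_true(x != 0, axis=1).tolist())  # [0, 1, -1]  first nonzero
```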

I already accounted for the failure case you describe by also returning any_nonz, which you can use to set rows without any nonzero to -1 with idx_first_nonz[any_nonz == 0] = -1, as done in my example.

Haha! I don't know how I missed that in your example; sorry for ranting on about it.

And yea the other thing is pretty easy to change :ok_hand: