I have a batch of N rows, each with M values sorted along dim=1. For each row, I want to find the index of the first nonzero element. I’d like to do this efficiently, without a for-loop.

import torch

x = torch.randn(5, 7)
x[x < 0] = 0
x, _ = x.sort(dim=1)  # sort() returns (values, indices)
first_nonzero = f(x)  # f is the function I am looking for

Here is another try, especially if you want to only use torch APIs.

import torch

def f(x):
    # non zero values mask
    non_zero_mask = x != 0
    # operations on the mask to find first nonzero values in the rows
    mask_max_values, mask_max_indices = torch.max(non_zero_mask, dim=1)
    # if the max-mask is zero, there is no nonzero value in the row
    mask_max_indices[mask_max_values == 0] = -1
    return mask_max_indices

x = torch.randn(4, 5)
x[x < 0] = 0
x, sort_indices = x.sort(dim=1)
print('x', x)
first_nonzero = f(x)
print(first_nonzero)

Thanks for the replies.
I’m interested in a PyTorch-only implementation to keep my code eco-friendly. Sorry for not mentioning this beforehand. @InnovArul, your solution assumes that torch.max always returns the index of the first occurrence of the maximum. I can’t find that guaranteed in the docs. Although it works with the current version, 0.4.1, I think it’s better not to rely on the internal implementation of PyTorch functions. For that reason, @adrianjav’s answer looks more robust, as long as ByteTensor supports the sum operation.
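@adrianjav’s post is not quoted in this thread, so the following is only a guess at the kind of sum-based approach meant here. It assumes the setup from the question: every row is non-negative and sorted in ascending order, so all zeros sit at the front and counting them gives the index of the first nonzero. The name first_nonzero_sorted is made up for illustration.

```python
import torch

def first_nonzero_sorted(x):
    # Assumption: rows are non-negative and sorted ascending along dim=1,
    # so the number of zeros in a row equals the index of its first nonzero.
    return (x == 0).sum(dim=1)

x = torch.tensor([[0.0, 0.0, 0.5, 1.2],
                  [0.0, 0.0, 0.0, 0.0],
                  [0.1, 0.3, 0.7, 2.0]])
idx = first_nonzero_sorted(x)
print(idx)
```

Note that a row with no nonzero value yields M (here 4) rather than -1, and that no tie-breaking behaviour of torch.max is involved.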

I was looking for a solution that does not assume the input is already sorted. Since this assumption was not mentioned in the topic title, it may be useful to have a solution that works in that case as well. I think the following should work in the general case. The idea is that an element is the first nonzero element if it is nonzero and the cumulative sum of a nonzero indicator is 1.

import torch

def first_nonzero(x, axis=0):
    nonz = (x > 0)
    return ((nonz.cumsum(axis) == 1) & nonz).max(axis)

x = (torch.rand(10, 5) * 10 - 6).int().clamp(0, 10)
print(x)
# Returns whether each row has any nonzero, and the index of the first nonzero (0 if there is none)
any_nonz, idx_first_nonz = first_nonzero(x, axis=1)
print(any_nonz, idx_first_nonz)
# If you want -1 for rows with no nonzeros
idx_first_nonz[any_nonz == 0] = -1
print(idx_first_nonz)
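To make the cumulative-sum idea concrete, here is a small trace on a hand-picked tensor. It is only a sketch: it uses x != 0 as the mask and casts to int before calling argmax, both of which are my choices rather than part of the code above. The point it illustrates is that the combined mask has at most one True per row, so the result does not depend on how ties are broken in rows that do contain a nonzero.

```python
import torch

x = torch.tensor([[0, 0, 3, 0, 5],
                  [0, 0, 0, 0, 0],
                  [2, 0, 0, 1, 0]])

nonz = x != 0                   # indicator of nonzero entries
running = nonz.cumsum(dim=1)    # running count of nonzeros along each row
first = (running == 1) & nonz   # True only at the first nonzero of each row

per_row = first.sum(dim=1)      # number of True entries in each row
idx = first.int().argmax(dim=1) # position of that single True
idx[per_row == 0] = -1          # rows without any nonzero

print(running)
print(first)
print(idx)
```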

@wouter’s solution above was what I was looking for.

I would like to point out two things, though.

First, it actually finds the first positive element, not the first nonzero element. I.e. because the mask is x > 0, it finds the first element that is strictly larger than zero and skips negative values.

Second, it has a failure case when no element along the axis is positive. In this case it returns 0 as the index, which suggests that there is a positive element at index zero, which is wrong.

One way to improve this is to find the places where the sum of (x > 0) along the axis equals zero, which means that no element along the axis was positive. We can then set those to some other value, e.g. -1 (or `float('nan')`, if the indices are first cast to float). We can use the already computed cumulative sum to do this.

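import torch

def first_nonnegative(tensor, axis=0, value_for_all_nonnegative=-1):
    # Note: despite the names, this mask selects strictly positive entries
    nonnegative = (tensor > 0)
    cumsum = nonnegative.cumsum(axis)
    # The last running count along axis is 0 where nothing passed the mask
    all_negative = cumsum.select(axis, -1) == 0
    # Reduce along the same axis to get the index of the single True per slice
    nonnegative_idx = ((cumsum == 1) & nonnegative).max(axis).indices
    nonnegative_idx[all_negative] = value_for_all_nonnegative
    return nonnegative_idx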

Glad that it was of help to you! You are right that it finds the first positive value rather than the first nonzero one (I guess I implicitly assumed only non-negative values), but you can easily change nonz = (x > 0) to nonz = (x != 0).

The failure case you describe I already accounted for by returning any_nonz as well, which you can use to set rows without any nonzero to -1 by idx_first_nonz[any_nonz == 0] = -1 as done in my example.
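Putting the two adjustments from this reply together (the x != 0 mask and the -1 marker), one possible consolidated helper might look like the sketch below. The name first_nonzero_index and the fill parameter are hypothetical, and the cast to int before argmax is a defensive choice on my part, not something the thread prescribes.

```python
import torch

def first_nonzero_index(x, axis=0, fill=-1):
    nonz = x != 0                                 # true nonzero mask, negatives included
    is_first = (nonz.cumsum(axis) == 1) & nonz    # at most one True along axis
    any_nonz = nonz.any(axis)                     # slices that contain a nonzero at all
    idx = is_first.int().argmax(axis)             # index of that single True
    # Replace the index by `fill` wherever the slice has no nonzero
    return torch.where(any_nonz, idx, torch.full_like(idx, fill))

x = torch.tensor([[0, -2, 3],
                  [0,  0, 0],
                  [4,  0, -1]])
rows = first_nonzero_index(x, axis=1)
cols = first_nonzero_index(x, axis=0)
print(rows, cols)
```

Returning a single tensor with a fill value is a design choice; returning any_nonz alongside the indices, as in the original function, keeps the "no nonzero" information separate from the index itself.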