Index of the last one in a tensor

I have a tensor that looks like this:

tensor([[1., 1., 0., 1., 0., 0., 1., 1., 0., 1.],
        [1., 1., 0., 1., 0., 0., 1., 1., 0., 0.],
        [1., 1., 0., 1., 0., 0., 0., 0., 0., 1.]])

What's the best way to find the index of the last 1 in each row? For the tensor above, it should return [9, 7, 9].

Suppose your tensor is a. Then you can sort each row in ascending order and take the original index of the element that ends up last:

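# stable=True keeps tied 1s in their original order, so the last 1 sorts to the end
sorted_tensor, sort_idx = a.sort(dim=1, stable=True)
out = sort_idx[:, -1]  # tensor([9, 7, 9])

Pass stable=True: with the default sort, the relative order of equal values is not guaranteed, so the index of an earlier 1 could end up in the last position. Also note that a row containing no 1s returns its last column index.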
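If you want to avoid the O(n log n) sort, you can also flip each row and use argmax, which returns the index of the first maximal value when there are ties. This is a minimal sketch assuming a 2-D tensor of 0s and 1s; the helper name last_one_index is hypothetical, and returning -1 for rows with no 1 is a convention chosen here, not something from the question.

```python
import torch


def last_one_index(a: torch.Tensor) -> torch.Tensor:
    """Return the column index of the last 1 in each row of a 2-D 0/1 tensor.

    Rows with no 1 get -1 (a convention assumed for this sketch).
    """
    n_cols = a.shape[1]
    # Reverse each row so the last 1 becomes the first 1; argmax returns the
    # index of the first maximal value, i.e. the first 1 in the flipped row.
    first_in_flipped = torch.flip(a, dims=[1]).argmax(dim=1)
    # Map the position in the flipped row back to the original column index.
    idx = n_cols - 1 - first_in_flipped
    # argmax of an all-zero row is 0, which would wrongly map to n_cols - 1,
    # so mark rows without any 1 explicitly.
    has_one = (a == 1).any(dim=1)
    return torch.where(has_one, idx, torch.full_like(idx, -1))


a = torch.tensor([[1., 1., 0., 1., 0., 0., 1., 1., 0., 1.],
                  [1., 1., 0., 1., 0., 0., 1., 1., 0., 0.],
                  [1., 1., 0., 1., 0., 0., 0., 0., 0., 1.]])
print(last_one_index(a))  # tensor([9, 7, 9])
```

The flip plus argmax version does a single linear pass per row and handles all-zero rows explicitly, which the sort-based version does not.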