Use torch.nonzero() as index

Here is a NumPy example that uses nonzero as an index:

    import numpy as np

    a = np.random.randint(6, size=(4, 5, 3))
    idx = np.nonzero(a)  # tuple of three index arrays, one per dimension
    a[idx] = 0

This is the PyTorch equivalent:

    import torch

    a = torch.randint(6, size=(4, 5, 3))
    idx = torch.nonzero(a)  # or a.nonzero(); shape (N, 3), one row per nonzero entry
    # a[idx] = 0  # does not work as intended: a single (N, 3) index tensor indexes dim 0 only (and can raise an IndexError)
    # The line below works but is a little counter-intuitive. Could multi-dim tensors be accepted as an index instead?
    a[(idx[:, 0], idx[:, 1], idx[:, 2])] = 0

I actually like torch.nonzero's single tensor better than the tuple NumPy returns; it is more elegant. But I do not know of a better way to use it as an index.
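A minimal sketch of a less manual conversion, assuming a PyTorch version where Tensor.unbind and Tensor.count_nonzero are available: unbind(dim=1) turns the (N, 3) result of nonzero into a tuple of three 1-D tensors, which is the layout NumPy's nonzero returns and which can be used directly as an index. tuple(idx.t()) would give the same tuple.

```python
import torch

a = torch.randint(6, size=(4, 5, 3))
idx = torch.nonzero(a)          # shape (N, 3): one row per nonzero entry

# unbind(dim=1) splits the (N, 3) tensor into a tuple of three 1-D tensors
# (one per dimension), the same layout NumPy's nonzero returns.
a[idx.unbind(dim=1)] = 0

print(a.count_nonzero())        # every entry is now zero
```

This avoids spelling out idx[:, 0], idx[:, 1], idx[:, 2] by hand and works for any number of dimensions.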


Maybe it is a bit unfortunate that nonzero has the dimensions in that order (one row per nonzero element, one column per dimension)…
You can make the split explicit and get something close to NumPy's nonzero indexing:

    import torch

    a = torch.threshold(torch.randn(5, 5, 5), 1, 0)  # entries <= 1 become 0
    idx = torch.nonzero(a).split(1, dim=1)  # tuple of three (N, 1) index tensors
    a[idx]
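One detail worth noting about the split version, shown in a small sketch: because split(1, dim=1) produces (N, 1) tensors rather than 1-D ones, a[idx] comes back with shape (N, 1). A boolean mask such as a != 0 (a different technique from nonzero, included only for comparison) selects the same values in the same row-major order as a flat (N,) tensor.

```python
import torch

a = torch.threshold(torch.randn(5, 5, 5), 1, 0)   # entries <= 1 become 0
idx = torch.nonzero(a).split(1, dim=1)            # tuple of three (N, 1) tensors

picked = a[idx]        # shape (N, 1) because each index tensor is (N, 1)
flat = a[a != 0]       # boolean mask: shape (N,), same values, same order

assert picked.shape == (flat.numel(), 1)
assert torch.equal(picked.squeeze(1), flat)
```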

Best regards

Thomas


Thanks! .split makes it easier. As a suggestion: if PyTorch tensors could also be indexed by another tensor of coordinates, that could be even more intuitive than NumPy.
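Until something like that exists, one way to read values at an (N, ndim) coordinate tensor without building a tuple is to turn each coordinate row into a flat row-major offset. The helper below, index_by_coords, is hypothetical (not a PyTorch API) and only sketches the idea for reading values.

```python
import torch


def index_by_coords(t: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: read t at an (N, t.dim()) tensor of coordinates."""
    # Row-major strides of a contiguous tensor with t's shape,
    # e.g. (15, 3, 1) for shape (4, 5, 3).
    strides = [1] * t.dim()
    for d in range(t.dim() - 2, -1, -1):
        strides[d] = strides[d + 1] * t.shape[d + 1]
    stride_t = torch.tensor(strides, dtype=coords.dtype, device=coords.device)
    offsets = (coords * stride_t).sum(dim=1)   # one flat offset per coordinate row
    # reshape(-1) flattens in row-major order (a view if t is contiguous).
    return t.reshape(-1)[offsets]


a = torch.randint(6, size=(4, 5, 3))
coords = torch.nonzero(a)
vals = index_by_coords(a, coords)
print(torch.equal(vals, a[a != 0]))
```

Writing through the same offsets would additionally require a contiguous tensor so that view(-1) aliases its storage, which is one reason the tuple form is simpler in practice.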

You just need to set as_tuple=True:

    a[a.nonzero(as_tuple=True)]
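A short sketch comparing this with the original NumPy example, assuming NumPy is installed alongside PyTorch: as_tuple=True returns one 1-D tensor per dimension, stacking them along dim 1 recovers the default (N, 3) result, and the tuple zeros the same entries NumPy does.

```python
import numpy as np
import torch

a = torch.randint(6, size=(4, 5, 3))
a_np = a.numpy().copy()          # independent NumPy copy for comparison

idx = a.nonzero(as_tuple=True)   # tuple of three 1-D tensors, like np.nonzero(a_np)
coords = a.nonzero()             # default form: one (N, 3) tensor

# Both forms list the same coordinates in the same row-major order.
same_coords = torch.equal(torch.stack(idx, dim=1), coords)

# The tuple form is used directly as an index, exactly as in NumPy.
a[idx] = 0
a_np[np.nonzero(a_np)] = 0
same_result = np.array_equal(a.numpy(), a_np)
```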


Hurray, this is now available!