How to most efficiently select/modify the indices given by .nonzero()?

Suppose I have a tensor x, which is mostly zeros, but does have some elements different from zero.

I can get the indices of those elements with indices = x.nonzero().

Now I would like to increase the values of all those elements by 1. What is the most efficient way to do that?

x[indices] += 1 does not work.
I could of course use a for loop that handles one index at a time, but that seems inefficient. Surely there is an elegant way to select all the indices returned by nonzero() at once that I just don’t know about?
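The reason x[indices] += 1 misbehaves shows up in the shapes. The sketch below (example tensor chosen for illustration) shows that nonzero() returns one row of coordinates per element, and that indexing with that single 2-D tensor selects whole rows along dim 0. Splitting the coordinates into one index tensor per dimension, which is what nonzero(as_tuple=True) returns, gives element-wise selection instead.

```python
import torch

x = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]])

# nonzero() returns one row per non-zero element: shape (num_nonzero, x.dim())
indices = x.nonzero()
print(indices.shape)  # torch.Size([3, 2])

# A single 2-D LongTensor index selects along dim 0 only, so x[indices]
# gathers whole rows (shape (3, 2, 3)) rather than individual elements.
print(x[indices].shape)  # torch.Size([3, 2, 3])

# One 1-D index tensor per dimension gives element-wise advanced indexing.
rows, cols = indices.unbind(1)
x[rows, cols] += 1

# nonzero(as_tuple=True) returns that tuple of per-dimension indices directly.
y = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
y[y.nonzero(as_tuple=True)] += 1

print(x.tolist())  # [[2, 0, 0], [0, 3, 0], [0, 0, 4]]
print(torch.equal(x, y))  # True
```

The as_tuple form is an alternative to the index_add_ approach discussed in the answer; both touch every non-zero element in one vectorized operation.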

You can use the in-place Tensor.index_add_() method to increase the values of the non-zero elements by 1 in a single vectorized call.

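import torch

# Create a tensor with some non-zero elements
x = torch.tensor([[1, 0, 0], [0, 2, 0], [0, 0, 3]])

# Get the indices of the non-zero elements; shape is (num_nonzero, 2)
indices = x.nonzero()

# index_add_ needs a 1-D index, so turn each (row, col) pair into a
# position in the flattened tensor (x is contiguous and 2-D here)
flat_indices = indices[:, 0] * x.size(1) + indices[:, 1]

# x.view(-1) shares storage with x, so adding in place through it updates x;
# the values to add must have the same dtype as x (int64 here)
x.view(-1).index_add_(0, flat_indices, torch.ones(flat_indices.size(0), dtype=x.dtype))

print(x.tolist())  # [[2, 0, 0], [0, 3, 0], [0, 0, 4]]

index_add_ is an in-place method, so it is called on the tensor being modified, here the flattened view x.view(-1). The first argument is the dimension along which the indices apply; it is 0 because the flattened view has a single dimension. The second argument is a 1-D tensor of indices, and the third is the tensor of values to add at those positions, which must have the same dtype as x. Passing the 2-D output of nonzero() directly does not work, because index_add_ expects a 1-D index.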
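Because index_add_ only accepts a 1-D index along one dimension, the coordinates from nonzero() have to be converted to positions in the flattened tensor. For a contiguous tensor of any rank, one way to do that is to multiply each coordinate by the corresponding stride and sum. The helper add_at_nonzero below is a hypothetical name used only for illustration, and the sketch assumes x is contiguous.

```python
import torch

def add_at_nonzero(x: torch.Tensor, value=1) -> torch.Tensor:
    """Hypothetical helper: add `value` in place to every non-zero element of x."""
    # Assumption: x is contiguous, so x.view(-1) shares its storage and
    # x.stride() gives the row-major offset of each dimension.
    assert x.is_contiguous(), "sketch assumes a contiguous tensor"
    coords = x.nonzero()  # (num_nonzero, x.dim())
    strides = torch.tensor(x.stride(), dtype=coords.dtype, device=coords.device)
    flat = (coords * strides).sum(dim=1)  # linear offset of each element
    src = torch.full((flat.numel(),), value, dtype=x.dtype, device=x.device)
    x.view(-1).index_add_(0, flat, src)
    return x

# 3-D example: non-zero elements sit at flat offsets 1, 4 and 7
x = torch.tensor([[[0, 5], [0, 0]], [[7, 0], [0, 1]]])
add_at_nonzero(x)
print(x.tolist())  # [[[0, 6], [0, 0]], [[8, 0], [0, 2]]]
```

For a 2-D tensor the stride sum reduces to row * x.size(1) + col, the same conversion used for the 2-D case.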

Using index_add_ is much faster than a Python for loop because the whole update is a single vectorized tensor operation instead of one indexing operation per element.
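To check the efficiency claim on your own machine, a rough sketch like the one below compares a per-element Python loop with a single index_add_ call on a randomly generated sparse-ish tensor. The tensor size, the function names and the use of time.perf_counter are illustrative choices; the timings depend entirely on hardware, so none are given here.

```python
import time
import torch

def add_one_loop(x):
    # One Python-level indexing operation per non-zero element
    for i, j in x.nonzero().tolist():
        x[i, j] += 1
    return x

def add_one_vectorized(x):
    # One index_add_ call over the flattened (contiguous, 2-D) tensor
    idx = x.nonzero()
    flat = idx[:, 0] * x.size(1) + idx[:, 1]
    x.view(-1).index_add_(0, flat, torch.ones_like(flat, dtype=x.dtype))
    return x

torch.manual_seed(0)
# Roughly half the entries are zero, the rest are in 1..9
base = torch.randint(0, 2, (200, 200)) * torch.randint(1, 10, (200, 200))

a = add_one_loop(base.clone())
b = add_one_vectorized(base.clone())
print(torch.equal(a, b))  # True: both add 1 to exactly the non-zero entries

for fn in (add_one_loop, add_one_vectorized):
    t = base.clone()
    start = time.perf_counter()
    fn(t)
    print(fn.__name__, f"{time.perf_counter() - start:.4f} s")
```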
