How to use the where function like NumPy

I ran the following code and got an error.

NumPy allows us to call where with only the condition, omitting the other 2 arguments, but PyTorch does not.
What I want to do is select the rows that match a value I specify.
I would really appreciate it if you could tell me how to use this function.

From looking at NumPy's docs, when only a single argument is given, it is equivalent to condition.nonzero(). So just do (a[:, 0] == 1).nonzero()?
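
Here is a minimal sketch of that suggestion. The tensor `a` and its values are made up for illustration, since the original code is not shown. Note that, by default, `nonzero()` in PyTorch returns a 2-D tensor with one row of indices per match, so the sketch squeezes it before indexing. Boolean-mask indexing is shown as an alternative that should give the same rows.

```python
import torch

# Hypothetical example data; the tensor from the original post is not shown.
a = torch.tensor([[1, 10],
                  [2, 20],
                  [1, 30]])

mask = a[:, 0] == 1          # boolean mask: True where the first column equals 1
idx = mask.nonzero()         # 2-D tensor of indices, one row per match
row_idx = idx.squeeze(1)     # flatten to a 1-D tensor of row indices

rows = a[row_idx]            # select the matching rows by index
rows_via_mask = a[mask]      # alternative: index directly with the boolean mask
```

If only the rows are needed and not their indices, indexing with the mask directly skips the `nonzero()` step.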
