I can get the indices of the matrix elements whose values satisfy a particular condition with:
import numpy as np
matrix = np.random.randint(0, 8, (4, 4))
inds = np.where(matrix < 5)
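To pin down what "the same result" means here, below is a small sketch that swaps the random matrix for a fixed one (the values are hypothetical, chosen only so the output is reproducible). It shows that the single-argument np.where returns a tuple with one index array per dimension, which is the same thing np.nonzero returns.

```python
import numpy as np

# Fixed 4x4 matrix with hypothetical values, so the output is reproducible.
matrix = np.array([[0, 7, 3, 5],
                   [6, 1, 7, 2],
                   [5, 5, 4, 7],
                   [7, 0, 6, 6]])

# One-argument form: a tuple (row_indices, col_indices) of the elements < 5.
rows, cols = np.where(matrix < 5)
print(rows)  # [0 0 1 1 2 3]
print(cols)  # [0 2 1 3 2 1]

# The one-argument form matches np.nonzero applied to the boolean mask.
for a, b in zip(np.where(matrix < 5), np.nonzero(matrix < 5)):
    assert (a == b).all()

# np.argwhere stacks the same indices as one (row, col) pair per match.
pairs = np.argwhere(matrix < 5)
print(pairs.shape)  # (6, 2)
```

Each pair (rows[k], cols[k]) is the position of one matching element, listed in row-major order.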
How do I achieve the same result using torch::where()?
It seems that torch::where() has a different API from np.where().
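Part of the confusion may be that np.where itself has two forms: the one-argument form used above returns indices, while the three-argument form np.where(condition, x, y) selects values elementwise from x or y and returns an array of the same shape. A minimal NumPy sketch of the three-argument form, with hypothetical values:

```python
import numpy as np

# Hypothetical 2x2 matrix for illustration.
matrix = np.array([[0, 7],
                   [3, 5]])

# Three-argument form: keep values < 5, replace the others with -1.
# The result has the same shape as the input; no indices are returned.
masked = np.where(matrix < 5, matrix, -1)
print(masked)
# [[ 0 -1]
#  [ 3 -1]]
```

When comparing against another library's where, it helps to check which of these two forms its documentation is describing.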