Get index info from torch.where

Hello. I found that torch.where(condition, x, y) [torch.where — PyTorch 1.9.1 documentation] doesn't seem to support getting the index location of each value that satisfies a specific condition, the way torch.max [torch.max — PyTorch 1.9.1 documentation] and torch.min [torch.min — PyTorch 1.9.1 documentation] do. Is there a way to get the indices when calling torch.where, or are there any alternatives? Thanks.
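A small sketch of the difference being described. The tensor x and the threshold 2 are made-up illustration values. torch.max with a dim argument returns both values and indices, while the three-argument torch.where only returns the selected elements:

```python
import torch

# Hypothetical example data
x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# torch.max along a dimension returns the values AND their indices
values, indices = torch.max(x, dim=1)
print(values)   # tensor([5., 6.])
print(indices)  # tensor([1, 2])

# torch.where(condition, x, y) returns only a tensor of picked elements
# (x where the condition holds, y elsewhere), with no index information
masked = torch.where(x > 2, x, torch.zeros_like(x))
print(masked)
```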

Try torch.nonzero(condition), perhaps with the as_tuple argument.
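A minimal sketch of torch.nonzero on a 2-D tensor. The data and the condition x > 2 are just illustration values. By default it returns one row of coordinates per matching element. With as_tuple=True it returns one 1-D index tensor per dimension, which can be used directly to index the original tensor:

```python
import torch

# Hypothetical example data
x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])
condition = x > 2

# Default: an (N, ndim) tensor, one row of coordinates per True element
coords = torch.nonzero(condition)
print(coords)
# tensor([[0, 1],
#         [0, 2],
#         [1, 0],
#         [1, 2]])

# as_tuple=True: a tuple with one 1-D index tensor per dimension
rows, cols = torch.nonzero(condition, as_tuple=True)
print(rows)           # tensor([0, 0, 1, 1])
print(cols)           # tensor([1, 2, 0, 2])

# The tuple form can index x directly to recover the matching values
print(x[rows, cols])  # tensor([5., 3., 4., 6.])
```

The torch.where documentation also notes that the single-argument form torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True), so that call returns the same index tuple.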