torch.where not working for float('nan')

import torch
var = torch.tensor([float('nan'), 1.0])

Why does torch.where(var == float('nan')) return an empty result instead of index 0?

(tensor([], dtype=torch.int64),)

This is expected. Under IEEE 754 floating-point rules, NaN compares unequal to every value, including itself, so NaN == NaN is always False:

x = torch.tensor(float("nan"))
x == x
# tensor(False)

import numpy as np
a = np.array(np.nan)
a == a
# False

float("nan") == float("nan")
# False
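Because NaN is the only floating-point value that compares unequal to itself, the self-inequality x != x works as a NaN test. A minimal pure-Python sketch of that property, cross-checked against math.isnan (the helper name find_nan_indices is hypothetical, used only for illustration):

```python
import math

def find_nan_indices(values):
    # Hypothetical helper: NaN is the only float for which x != x is True,
    # so this keeps exactly the NaN positions (inf == inf, so inf is skipped).
    return [i for i, x in enumerate(values) if x != x]

data = [float("nan"), 1.0, float("inf"), float("nan")]

print(find_nan_indices(data))                            # [0, 3]
print([i for i, x in enumerate(data) if math.isnan(x)])  # [0, 3]
```

In practice a dedicated check such as math.isnan (or torch.isnan for tensors) states the intent more clearly than the x != x trick.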

Thanks for the response!
So there is no direct way to get the indices of NaN values in torch? Do I have to use an indirect approach, like replacing NaNs with a specific value and then calling torch.where() on that value?

You could use torch.isnan:

torch.where(torch.isnan(var))
# (tensor([0]),)
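The same approach extends to multi-dimensional tensors. A short sketch, assuming a small 2-D example tensor m chosen for illustration: torch.isnan produces a boolean mask, and torch.where called with only a condition returns one index tensor per dimension (equivalent to torch.nonzero(..., as_tuple=True)).

```python
import torch

# Illustrative 2-D tensor with NaNs at positions (0, 1) and (1, 0).
m = torch.tensor([[1.0, float("nan")],
                  [float("nan"), 2.0]])

# The == comparison never matches NaN, so no indices are found.
print(torch.where(m == float("nan"))[0].numel())  # 0

# torch.isnan builds a boolean mask; single-argument torch.where
# returns a tuple (row_indices, col_indices).
rows, cols = torch.where(torch.isnan(m))
print(rows.tolist(), cols.tolist())  # [0, 1] [1, 0]

# nonzero() on the mask gives the same positions as (N, ndim) pairs.
print(torch.isnan(m).nonzero().tolist())  # [[0, 1], [1, 0]]
```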

Thank you for the quick solution!