var = torch.tensor([float('nan'), 1.0])
Why does torch.where(var == float('nan')) return an empty result here? I expected index 0, but I get:
(tensor([], dtype=torch.int64),)
This is expected: NaN == NaN returns False by definition (IEEE 754 floating-point semantics), in PyTorch, NumPy, and plain Python alike:
import numpy as np
import torch

x = torch.tensor(float("nan"))
x == x
# tensor(False)
a = np.array(np.nan)
a == a
# False
float("nan") == float("nan")
# False
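To make this concrete, here is a small sketch (assuming PyTorch and NumPy are installed) that prints the mask torch.where actually receives in the original example. It also shows the self-inequality trick var != var, which works because NaN is the only floating-point value not equal to itself. The trick is included only to illustrate the rule; torch.isnan is the more readable choice.

```python
import numpy as np
import torch

var = torch.tensor([float("nan"), 1.0])

# Comparing against NaN is False at every position, including index 0,
# so torch.where receives an all-False mask and returns no indices.
mask_eq = var == float("nan")
print(mask_eq)               # tensor([False, False])
print(torch.where(mask_eq))  # (tensor([], dtype=torch.int64),)

# NaN is the only value that is not equal to itself, so x != x flags NaNs.
mask_ne = var != var
print(torch.where(mask_ne))  # (tensor([0]),)

# NumPy follows the same IEEE 754 rule.
arr = np.array([np.nan, 1.0])
print(arr == np.nan)         # [False False]
```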
Thanks for the response!
So is there no direct way to get the indices of NaNs in torch? Do I have to use an indirect approach, like replacing the NaNs with a specific value and then calling torch.where() on that value?
You could use torch.isnan:
torch.where(torch.isnan(var))
# (tensor([0]),)
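The same approach carries over to multi-dimensional tensors. In the sketch below, the 2-D tensor x is a made-up example: torch.where with a single boolean argument returns one index tensor per dimension, and nonzero() returns the same positions as (row, col) pairs. It also shows that replacing NaNs, as asked above, is not needed to find them, though masked_fill can do the replacement afterwards if you want it.

```python
import torch

nan = float("nan")
# Hypothetical example tensor with NaNs at (0, 1) and (1, 0).
x = torch.tensor([[1.0, nan],
                  [nan, 2.0]])

nan_mask = torch.isnan(x)  # boolean tensor with the same shape as x

# One index tensor per dimension, in row-major order.
rows, cols = torch.where(nan_mask)
print(rows, cols)  # tensor([0, 1]) tensor([1, 0])

# The same positions as an (N, ndim) coordinate tensor.
coords = nan_mask.nonzero()

# Counting or checking for NaNs needs no index lookup at all.
num_nans = int(nan_mask.sum())
has_nan = bool(nan_mask.any())

# Optional: replace NaNs with 0.0 in a new tensor, if that is what you need.
x_filled = x.masked_fill(nan_mask, 0.0)
```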
Thank you for the quick solution!