Hi,
I’ve run into an issue where a sparse matrix multiplication produces NaN values for no apparent reason. I can’t share the source code or the data, and it doesn’t always happen, so it’s quite hard to reproduce.
I’m just wondering if anyone else has encountered anything similar? It seems random, and I’ve no idea why it sometimes breaks and sometimes doesn’t.
Here’s an example of the code that breaks:
import torch
from torch import Tensor

def my_foo(a: Tensor, b: Tensor, sparse_matrix: Tensor) -> Tensor:
    c = a + 1j * b  # torch.complex128
    assert not torch.isnan(c).any()  # all good, no NaNs
    # originally torch.int64 (a matrix of ones and zeros), sparse CSR
    sparse_matrix = sparse_matrix.to(c.dtype)
    assert not torch.isnan(sparse_matrix.values()).any()  # all good, no NaNs
    # the dimensions are c: (32700,) and sparse_matrix: (302200, 32700), so quite large
    result = sparse_matrix @ c
    assert not torch.isnan(result).any()  # this assertion fails: some entries are NaN, not all
    return result
So neither sparse_matrix nor c contains NaNs, yet multiplying them produces NaNs.
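To narrow this down, it may help to look at the failing rows directly. Below is a minimal diagnostic sketch, not a confirmed explanation of the problem. The helper name diagnose_nan_rows is hypothetical, and it assumes sparse_matrix is a CSR tensor already converted to the dtype of c. It does two things: it checks the inputs with torch.isfinite rather than torch.isnan (an isnan check does not catch infinities, and inf - inf or 0 * inf evaluates to NaN), and it recomputes each NaN row of the result by hand from the CSR components so the value can be compared with what the sparse product returned.

```python
import torch
from torch import Tensor


def diagnose_nan_rows(sparse_matrix: Tensor, c: Tensor, result: Tensor):
    """Hypothetical helper: inspect the NaN rows of result = sparse_matrix @ c.

    Assumes sparse_matrix is a 2-D sparse CSR tensor with the same dtype as c.
    Returns (inputs_finite, report), where report is a list of
    (row_index, reference_value) pairs for every row whose result is NaN.
    """
    values = sparse_matrix.values()
    # isnan() alone misses +/-inf, which can become NaN during the product.
    inputs_finite = bool(torch.isfinite(c).all()) and bool(torch.isfinite(values).all())

    crow = sparse_matrix.crow_indices()
    col = sparse_matrix.col_indices()

    report = []
    for row in torch.nonzero(torch.isnan(result)).flatten().tolist():
        start, end = int(crow[row]), int(crow[row + 1])
        # Recompute this row as a plain elementwise product and sum,
        # bypassing the sparse matmul kernel.
        reference = (values[start:end] * c[col[start:end]]).sum().item()
        report.append((row, reference))
    return inputs_finite, report
```

If inputs_finite is False, the NaNs most likely come from infinities in the inputs. If it is True and the hand-computed reference values are finite while the sparse product is NaN, that would point at the sparse matmul itself, in which case the PyTorch version, the device, and whether the sparse matrix is contiguous and freshly converted are worth reporting.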