Sparse matrix multiplication randomly causes NaNs

Hi,

I’ve run into an issue where a sparse matrix multiplication produces NaN values even though neither input contains any. I can’t share the source code or the data, and it doesn’t happen on every run, so it’s quite hard to reproduce.

I’m just wondering if anyone else has encountered anything similar? It seems random, and I’ve no idea why it sometimes breaks and sometimes doesn’t.

Here’s an example of the code that breaks:

    import torch
    from torch import Tensor

    def my_foo(a: Tensor, b: Tensor, sparse_matrix: Tensor) -> Tensor:
        c = a + 1j * b  # torch.complex128
        assert not torch.isnan(c).any()  # passes: no NaNs

        # originally torch.int64 (a matrix of ones and zeros), sparse CSR
        sparse_matrix = sparse_matrix.to(c.dtype)
        assert not torch.isnan(sparse_matrix.values()).any()  # passes: no NaNs

        # shapes: c is (32700,) and sparse_matrix is (302200, 32700), so quite large
        result = sparse_matrix @ c

        # this assertion fails: some (not all) entries of result are NaN
        assert not torch.isnan(result).any()
        return result

So neither sparse_matrix nor c contains NaNs, yet multiplying them produces NaNs.
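
Two checks that might help narrow this down, as a rough sketch rather than a diagnosis. The assertions above only test for NaN, so an infinite entry in c would pass them, and infinities can turn into NaN during accumulation (for example inf - inf or 0 * inf). The helper names `count_nonfinite` and `reference_spmv` below are hypothetical, not part of PyTorch. The second one assumes the matrix is still the original real-valued (int64) CSR tensor and c is a 1-D complex vector, and it recomputes the product on the real and imaginary parts separately so the result can be compared with `sparse_matrix @ c`.

```python
import torch
from torch import Tensor


def count_nonfinite(x: Tensor) -> dict:
    """Count NaN and infinite entries (hypothetical helper).

    For complex tensors an entry counts if either its real or imaginary part matches.
    """
    return {
        "nan": int(torch.isnan(x).sum()),
        "inf": int(torch.isinf(x).sum()),
    }


def reference_spmv(csr: Tensor, c: Tensor) -> Tensor:
    """Reference product of a real-valued CSR matrix and a complex vector (hypothetical helper).

    Assumes `csr` is a sparse CSR tensor with real values (e.g. int64) and `c` is
    a 1-D complex tensor. Real and imaginary parts are accumulated separately
    with index_add_, so no complex sparse kernel is involved.
    """
    crow = csr.crow_indices()
    col = csr.col_indices()
    real_dtype = c.real.dtype
    vals = csr.values().to(real_dtype)
    n_rows = csr.shape[0]

    # Expand the compressed row pointers into one row index per stored value.
    counts = crow[1:] - crow[:-1]
    row_ids = torch.repeat_interleave(
        torch.arange(n_rows, device=crow.device), counts
    )

    re = torch.zeros(n_rows, dtype=real_dtype, device=c.device)
    im = torch.zeros(n_rows, dtype=real_dtype, device=c.device)
    re.index_add_(0, row_ids, vals * c.real[col])
    im.index_add_(0, row_ids, vals * c.imag[col])
    return torch.complex(re, im)
```

If `count_nonfinite(c)` reports any infinities, the inputs are the first thing to look at. If it does not, and `reference_spmv` gives finite values in the rows where `sparse_matrix @ c` gives NaN, that would point toward the sparse matmul itself rather than the data.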