How can we calculate the memory required by torch.sparse.matmul? A dense (m, k) * (k, n) matmul creates an (m, n)-shaped output tensor, but I can't figure out how much memory a sparse-sparse matmul requires. It seems to be some function of tensor._nnz(), but I want to know exactly what it is. Sometimes torch.sparse.matmul tries to allocate an unreasonable amount of memory and fails with an OOM error.
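
For reference, here is a rough sketch of the kind of estimate I am after. It is not PyTorch's actual allocation logic, which I have not been able to confirm. It assumes COO storage with int64 indices and float32 values, and it bounds the output nnz by the number of scalar products a sparse-sparse matmul has to form (for each inner index k, nonzeros in column k of A times nonzeros in row k of B). The helper names are hypothetical, and any backend workspace would come on top of this.

```python
from collections import Counter

def dense_output_bytes(m, n, itemsize=4):
    # Dense (m, k) * (k, n) output: m * n elements (float32 assumed).
    return m * n * itemsize

def coo_bytes(nnz, ndim=2, itemsize=4, index_itemsize=8):
    # Assumed COO layout: one int64 index per dimension plus one value per nonzero.
    return nnz * (ndim * index_itemsize + itemsize)

def spgemm_nnz_upper_bound(a_indices, b_indices, m, n):
    # a_indices, b_indices: (row, col) pairs of the nonzeros, assumed coalesced
    # (no duplicate coordinates). Hypothetical helper, not a PyTorch function.
    a_cols = Counter(col for _, col in a_indices)
    b_rows = Counter(row for row, _ in b_indices)
    # Each nonzero in column k of A meets each nonzero in row k of B.
    products = sum(count * b_rows[k] for k, count in a_cols.items())
    # The output can never hold more than m * n entries.
    return min(products, m * n)

# Toy 3x3 example.
a = [(0, 0), (1, 0), (2, 1)]
b = [(0, 0), (0, 2), (1, 1)]
bound = spgemm_nnz_upper_bound(a, b, m=3, n=3)
print(bound, coo_bytes(bound), dense_output_bytes(3, 3))
```

Note that this is only an upper bound on the output size: products landing on the same output coordinate get summed, so the real nnz can be smaller, while the temporary memory used during the computation may be larger.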