Problem with multiplying a batch of sparse tensors with a batch of dense tensors

Hi,

I am facing some problems with batched sparse multiplication. I have a batch of sparse graphs, A (n x 10000 x 10000), which I am storing as a list of sparse tensors, A_s = [sparse_tensor(A_1), …]. My input matrix X has the dimensions [n x 10000 x 10]. I want to do a batched multiplication, i.e. the first graph gets multiplied with the input from subject 1, the second with the input from subject 2, and so on.

To do that, I first place all my graphs A_i on the diagonal of one big block-diagonal graph, manually keeping track of the non-zero indices and non-zero values, and build a big sparse matrix A_big whose diagonal blocks are those graphs. The dense version of this matrix would have size (n*10000, n*10000). After that I reshape my X into X_modified (n*10000 x 10) and compute torch.sparse.mm(A_big, X_modified). The forward pass works fine, but during the backward pass I get the error

"RuntimeError: CUDA out of memory. Tried to allocate 2875.53 GiB (GPU 0; 3.94 GB total capacity; 1.31 GiB already allocated; 1.68 GiB free; 1.44 GiB reserved in total by PyTorch)"

However, if I do the same thing using a for loop, performing the sparse multiplication individually for each subject, I do not get this error.
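For comparison, the loop version (continuing the toy sketch above) looks roughly like this, and it backpropagates without blowing up:

```python
# Per-subject loop over the same toy A_s and X from the sketch above.
outputs = []
for i, (idx, val) in enumerate(A_s):
    A_i = torch.sparse_coo_tensor(idx, val, (d, d))
    outputs.append(torch.sparse.mm(A_i, X[i]))   # (d, f) result for subject i
out_loop = torch.stack(outputs)                  # (n, d, f)
out_loop.sum().backward()                        # no OOM in this version
```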

To give some context, the non-zero values of my sparse graphs are learnable through an attention mechanism with 10 learnable parameters. With the sparse implementation I am only storing the non-zero values, so I don't understand why I am getting this error.
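In case it matters, a hypothetical sketch of how the values end up learnable (w and edge_feats are made-up names, my actual attention mechanism is different; the point is that only the non-zero values and the 10 parameters are ever stored):

```python
# Hypothetical stand-in for the attention mechanism: 10 learnable parameters
# map per-edge features to the non-zero values of one graph.
w = torch.nn.Parameter(torch.rand(10))       # the 10 learnable parameters
edge_feats = torch.rand(d, 10)               # made-up per-edge features, one row per non-zero
val = edge_feats @ w                         # non-zero values used to build the sparse graph
```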

Any help is really appreciated!

Thank you!