Hi,
I’m trying to calculate a gradient w.r.t. a sparse matrix. It seems like PyTorch’s autograd doesn’t support computing the gradient for a sparse matrix, so I would like to calculate it manually, if that’s possible.
The forward function is softmax(A*AXW), where A is a sparse matrix, and I want to calculate the gradient w.r.t. A.
I think it’d be easier to read if you wrote it as PyTorch code. I cannot quite make out what A, X, and W are, or what exactly the expression inside the softmax is.
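For illustration only, here is one possible reading written as PyTorch code. It assumes that A*AXW means the matrix product A @ A @ X @ W, that the softmax is taken row-wise (dim=1), and that A is an n x n sparse COO tensor; the shapes, the stored entries, and the weighted-sum loss are made up for the sketch. Under those assumptions, torch.sparse.mm already propagates a gradient to A, restricted to A’s stored entries.

```python
import torch

torch.manual_seed(0)
n, d, h = 4, 3, 2
# Hypothetical sparse A (n x n) with three stored entries
indices = torch.tensor([[0, 1, 3], [1, 2, 0]])
values = torch.randn(3, dtype=torch.float64)
A = torch.sparse_coo_tensor(indices, values, size=(n, n), requires_grad=True)
X = torch.randn(n, d, dtype=torch.float64)
W = torch.randn(d, h, dtype=torch.float64, requires_grad=True)
weights = torch.randn(n, h, dtype=torch.float64)  # arbitrary, only to build a scalar loss

# Assumed reading of softmax(A*AXW): softmax(A @ (A @ (X @ W))), row-wise
AXW = torch.sparse.mm(A, X @ W)
out = torch.softmax(torch.sparse.mm(A, AXW), dim=1)

loss = (out * weights).sum()
loss.backward()
print(A.grad)  # sparse COO tensor, nonzero only at A's stored entries
```

If the intended expression is different (for example an elementwise product somewhere), only the two lines building `AXW` and `out` would change.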
In general, for a matrix product Y = A @ B followed by some scalar-valued function f(Y), the gradients are df/dA = (df/dY) @ B.t() and df/dB = A.t() @ (df/dY). In PyTorch, df/dY is often called grad_out for the matrix multiplication, and here you can see backpropagation at work…
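A quick numerical check of these two formulas against autograd, using small dense tensors. The scalar function f(Y) = sum(sin(Y)) is an arbitrary choice for the sketch, picked so that df/dY = cos(Y) is easy to write down by hand.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
B = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)

Y = A @ B
f = torch.sin(Y).sum()  # arbitrary scalar-valued f(Y)
f.backward()

# Manual backward: grad_out = df/dY = cos(Y) for this particular f
grad_out = torch.cos(Y.detach())
grad_A_manual = grad_out @ B.detach().t()   # df/dA = (df/dY) @ B.t()
grad_B_manual = A.detach().t() @ grad_out   # df/dB = A.t() @ (df/dY)

print(torch.allclose(A.grad, grad_A_manual))  # True
print(torch.allclose(B.grad, grad_B_manual))  # True
```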
The problem is likely (and ha, now the formula is useful after all) that the computation df/dA = (df/dY) @ B.t() involves only dense matrices, and in general the gradient does not have the same sparsity pattern as A. What PyTorch does under the hood is compute the (large) dense derivative and then apply the sparsity pattern.
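A toy sketch of the dense-then-mask idea, only mimicking what is described above rather than reproducing PyTorch’s internal code. The 3 x 3 pattern and f(Y) = Y.sum() are assumptions chosen to keep it small. Note that the dense derivative does not even depend on A’s stored values, only on grad_out and B.

```python
import torch

torch.manual_seed(0)
# Sparsity pattern of a 3 x 3 matrix A: stored entries at (1, 0) and (2, 1)
indices = torch.tensor([[1, 2], [0, 1]])
B = torch.randn(3, 2, dtype=torch.float64)

# df/dY for f(Y) = Y.sum(), where Y = A @ B
grad_out = torch.ones(3, 2, dtype=torch.float64)

# Dense derivative df/dA = (df/dY) @ B.t(): a full 3 x 3 matrix,
# generally nonzero outside A's sparsity pattern as well
grad_A_dense = grad_out @ B.t()

# Apply the sparsity pattern: keep only the entries A actually stores
mask = torch.zeros(3, 3, dtype=torch.bool)
mask[indices[0], indices[1]] = True
grad_A_masked = torch.where(mask, grad_A_dense, torch.zeros_like(grad_A_dense))
```

For a large A, the 3 x 3 intermediate becomes an n x n dense matrix, which is why this approach gets expensive.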
You could try to implement your own sparse_mm with a custom backward, using scatter_add from the third-party PyTorch Scatter package (torch_scatter) or something similar.
Note that the sparse derivative “enforces” the sparsity pattern (this is often desired, even though there may be gradients outside the mask).
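A minimal sketch of such a sparse_mm, assuming A (n x m) is given in COO form as index tensors `row`, `col` and a `values` tensor, and B is dense. `SparseMM` is a hypothetical name, and it uses core PyTorch’s `Tensor.index_add_` in place of torch_scatter’s `scatter_add`, so no third-party package is needed. The backward returns df/dA = (df/dY) @ B.t() evaluated only at A’s stored entries, so the sparsity pattern is enforced, plus df/dB = A.t() @ (df/dY).

```python
import torch


class SparseMM(torch.autograd.Function):
    """Hypothetical sparse (COO) @ dense product with a hand-written backward."""

    @staticmethod
    def forward(ctx, row, col, values, B, n):
        ctx.save_for_backward(row, col, values, B)
        # Each stored entry A[i, j] contributes values * B[j, :] to Y[i, :]
        contrib = values.unsqueeze(1) * B[col]
        Y = torch.zeros(n, B.shape[1], dtype=B.dtype, device=B.device)
        Y.index_add_(0, row, contrib)
        return Y

    @staticmethod
    def backward(ctx, grad_out):
        row, col, values, B = ctx.saved_tensors
        grad_values = grad_B = None
        if ctx.needs_input_grad[2]:
            # (grad_out @ B.t())[i, j], computed only at the stored (i, j)
            grad_values = (grad_out[row] * B[col]).sum(dim=1)
        if ctx.needs_input_grad[3]:
            # A.t() @ grad_out, accumulated entry by entry
            grad_B = torch.zeros_like(B)
            grad_B.index_add_(0, col, values.unsqueeze(1) * grad_out[row])
        # No gradients for row, col, or n
        return None, None, grad_values, grad_B, None


# Usage: the same 3 x 3 pattern as a sparse COO matrix with entries at (1, 0) and (2, 1)
row = torch.tensor([1, 2])
col = torch.tensor([0, 1])
values = torch.randn(2, dtype=torch.float64, requires_grad=True)
B = torch.randn(3, 2, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(lambda v, b: SparseMM.apply(row, col, v, b, 3), (values, B))
print(ok)  # True
```

The design choice here is to never materialize the dense n x m gradient: the gathered rows `grad_out[row]` and `B[col]` have one row per stored entry, so memory scales with the number of nonzeros.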
Hi, did you mean that only the nonzero entries of the original sparse matrix receive the gradient from the dense matrix df/dA = (df/dY) @ B.t()? In other words, does the dense matrix df/dA have the same sparsity pattern as A? Thanks
Mathematically, it does not. For the mathematical gradient, the “missing” entries in the sparsity pattern are just entries that happen to be zero. The computation
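```python
import torch
a = torch.sparse_coo_tensor(torch.tensor([[1, 2], [0, 1]]), torch.randn(2), size=(3, 3), requires_grad=True)  # stored entries at (1, 0) and (2, 1)
b = torch.randn(2, 3, requires_grad=True)
torch.sparse.mm(a, b.t()).sum().backward()
print(a.grad)  # sparse, with the same stored entries as a
```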
gives you an a.grad with the same sparsity pattern as a, so it essentially computes the gradient w.r.t. a masked tensor (because mathematically, the gradient of a * mask w.r.t. a is the mask).
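To check the “masked tensor” view numerically, one can compare the sparse gradient from a snippet like the one above with the gradient of a dense leaf multiplied by a 0/1 mask of the same pattern. This is a sketch; the float64 dtype, the fixed seed, and the names `mask` and `a_dense` are choices made for the comparison.

```python
import torch

torch.manual_seed(0)
indices = torch.tensor([[1, 2], [0, 1]])
values = torch.randn(2, dtype=torch.float64)
b = torch.randn(2, 3, dtype=torch.float64)

# Sparse version, mirroring the snippet above
a = torch.sparse_coo_tensor(indices, values, size=(3, 3), requires_grad=True)
torch.sparse.mm(a, b.t()).sum().backward()

# Dense version: a dense leaf multiplied by a 0/1 mask of the sparsity pattern
mask = torch.zeros(3, 3, dtype=torch.float64)
mask[indices[0], indices[1]] = 1.0
a_dense = torch.zeros(3, 3, dtype=torch.float64)
a_dense[indices[0], indices[1]] = values
a_dense.requires_grad_()
((a_dense * mask) @ b.t()).sum().backward()

# d(a * mask)/da is the mask, so both gradients agree
print(torch.allclose(a.grad.to_dense(), a_dense.grad))  # True
```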
In their code (SPMMFunction, from line 8), I found that they only compute the gradient of the dense matrix in backward. Do you mean that dy/dA is also required, so that grad_out @ b.t() should be returned in backward (line 36) as well?