How can I apply a linear transformation to a sparse matrix in PyTorch?

In PyTorch, we have nn.Linear, which applies a linear transformation to the incoming data:

y = AW^T + b

In this formula, W and b are our learnable parameters and A is my input data matrix. In my case, A is too large to load fully into RAM as a dense matrix, so I store it in sparse form. Is it possible to perform such an operation on sparse matrices using PyTorch?

When you say you use it ‘sparsely’, do you mean you are using sparse tensors (i.e., tensors where a large proportion of the values are 0)? If so, then you can just represent your tensors using torch.sparse to do this. And then you can use torch.addmm to implement the linear layer.
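As a rough illustration of this suggestion, here is a minimal sketch that builds a small COO tensor from its nonzero entries and applies an affine map to it. The sizes and values are made up for the example, and it uses torch.sparse.mm followed by a broadcast add of the bias (a swapped-in equivalent of the addmm call mentioned above), then checks the result against the dense computation.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not from the thread).
n, in_features, out_features = 4, 5, 3

# Build a COO tensor directly from its nonzero entries,
# so the dense matrix never has to be materialized.
indices = torch.tensor([[0, 1, 1, 3],    # row indices
                        [2, 0, 4, 3]])   # column indices
values = torch.tensor([1.5, -2.0, 0.5, 3.0])
A = torch.sparse_coo_tensor(indices, values, size=(n, in_features))

# Dense learnable parameters, shaped like nn.Linear's weight and bias.
W = torch.randn(out_features, in_features, requires_grad=True)
b = torch.randn(out_features, requires_grad=True)

# Sparse @ dense gives a dense result; the bias is added by broadcasting.
y = torch.sparse.mm(A, W.t()) + b

# Reference computation on the dense copy (only feasible here because A is tiny).
y_dense = A.to_dense() @ W.t() + b
print(torch.allclose(y, y_dense, atol=1e-6))
```

The dense comparison is only there as a sanity check; with a matrix that does not fit in RAM you would skip it.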

Thanks for your response. Yes, I use COO format.
I define b as the bias and W as the weight matrix using nn.Parameter(). Given that W and b are the only two learnable parameters, and the COO sparse matrix is the input data (not a learnable parameter), which of torch.addmm() and torch.sparse.addmm() should I use?
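For concreteness, the setup described in this post might look roughly like the sketch below. The class name SparseInputLinear, the initialization, and the toy data are hypothetical, and the forward pass uses torch.sparse.mm plus a broadcast bias add as one possible way to write it, not necessarily the recommended one.

```python
import torch
from torch import nn


class SparseInputLinear(nn.Module):
    """Hypothetical module: dense learnable W and b, sparse COO input."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # W and b are the only learnable parameters (arbitrary small init).
        self.W = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.b = nn.Parameter(torch.zeros(out_features))

    def forward(self, A):
        # A is a sparse COO matrix of shape (n, in_features); the output is dense.
        return torch.sparse.mm(A, self.W.t()) + self.b


# Toy sparse input (made-up values); it is data, so it carries no gradient.
indices = torch.tensor([[0, 0, 1, 2],
                        [1, 3, 2, 4]])
values = torch.tensor([2.0, -1.0, 0.5, 4.0])
A = torch.sparse_coo_tensor(indices, values, size=(3, 5))

layer = SparseInputLinear(5, 3)
out = layer(A)
out.sum().backward()  # gradients flow to W and b only

print(layer.W.grad.shape, layer.b.grad.shape)
```

With a sum loss, each row of W.grad is the column sum of A, which matches the pattern of repeated rows in the gradients shown later in this thread.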

Hi Maryam!

Note, you can use a Linear and apply it to a sparse tensor. (You don’t
need to explicitly call something like torch.sparse.addmm().)

You don’t have to use Linear, of course, but it might be stylistically
better since it’s used so commonly.

If it better fits your use case, you can also use the functional form,
torch.nn.functional.linear().

Consider:

>>> import torch
>>> torch.__version__
'1.12.0'
>>> _ = torch.manual_seed (2022)
>>> s = ((torch.randn (3, 5) > 0.0) * torch.randn (3, 5)).to_sparse()   # make a sparse tensor
>>> s
tensor(indices=tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2],
                       [0, 1, 2, 3, 2, 3, 2, 3, 4]]),
       values=tensor([-1.0556, -0.5344,  1.2262,  1.1603, -0.1418, -0.3402,
                       0.6979, -0.5543, -0.2561]),
       size=(3, 5), nnz=9, layout=torch.sparse_coo)
>>> lin = torch.nn.Linear (5, 3)   # make a (non-sparse) Linear
>>> t = lin (s)
>>> t
tensor([[-0.4231, -0.6271, -0.6817],
        [-0.2014, -0.0587,  0.1683],
        [-0.3242,  0.1646,  0.3315]], grad_fn=<AddmmBackward0>)
>>> t.sum().backward()   # backward works
>>> lin.weight.grad
tensor([[-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561]])
>>> w = torch.nn.Parameter (lin.weight.detach().clone())
>>> b = torch.nn.Parameter (lin.bias.detach().clone())
>>> t = torch.nn.functional.linear (s, w, b)   # functional form of linear also works
>>> t.sum().backward()
>>> w.grad
tensor([[-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561]])

Best.

K. Frank

Thank you so much. It was very helpful.