Hi Maryam!
Note that you can create a Linear and apply it directly to a sparse tensor. (You
don’t need to explicitly call something like torch.sparse.addmm().)
You don’t have to use Linear, of course, but it might be stylistically
preferable, since it’s so commonly used.
If it better fits your use case, you can also use the functional form,
torch.nn.functional.linear().
Consider:
>>> import torch
>>> torch.__version__
'1.12.0'
>>> _ = torch.manual_seed (2022)
>>> s = ((torch.randn (3, 5) > 0.0) * torch.randn (3, 5)).to_sparse() # make a sparse tensor
>>> s
tensor(indices=tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2],
                       [0, 1, 2, 3, 2, 3, 2, 3, 4]]),
       values=tensor([-1.0556, -0.5344,  1.2262,  1.1603, -0.1418, -0.3402,
                       0.6979, -0.5543, -0.2561]),
       size=(3, 5), nnz=9, layout=torch.sparse_coo)
>>> lin = torch.nn.Linear (5, 3) # make a (non-sparse) Linear
>>> t = lin (s)
>>> t
tensor([[-0.4231, -0.6271, -0.6817],
        [-0.2014, -0.0587,  0.1683],
        [-0.3242,  0.1646,  0.3315]], grad_fn=<AddmmBackward0>)
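>>> torch.allclose (t, lin (s.to_dense()))   # sanity check: same result as applying lin to the dense version of s
True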
>>> t.sum().backward() # backward works
>>> lin.weight.grad
tensor([[-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561]])
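>>> s.to_dense().sum (dim = 0)   # t.sum() backpropagates ones, so each row of weight.grad is just the column-sums of s
tensor([-1.0556, -0.5344,  1.7823,  0.2658, -0.2561])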
>>> w = torch.nn.Parameter (lin.weight.detach().clone())
>>> b = torch.nn.Parameter (lin.bias.detach().clone())
>>> t = torch.nn.functional.linear (s, w, b) # functional form of linear also works
>>> t.sum().backward()
>>> w.grad
tensor([[-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561]])
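As an aside, if you did want to perform the sparse matrix multiply explicitly,
an untested sketch along these lines should reproduce lin (s). (Here I expand
the bias to the output shape by hand rather than rely on broadcasting.)
>>> t2 = torch.sparse.addmm (lin.bias.expand (3, 3), s, lin.weight.t())   # bias + s @ weight^T, with s sparse
>>> torch.allclose (t2, lin (s))   # should print True
True
But letting Linear (or the functional form) handle this for you is cleaner.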
Best.
K. Frank