Hi Maryam!
Note that you can take a Linear and apply it directly to a sparse tensor. (You don’t need to explicitly call something like torch.sparse.addmm(); see the sanity check after the example below.)
You don’t have to use Linear, of course, but it might be stylistically better, since it’s so commonly used.
If it better fits your use case, you can also use the functional form, torch.nn.functional.linear().
Consider:
>>> import torch
>>> torch.__version__
'1.12.0'
>>> _ = torch.manual_seed(2022)
>>> s = ((torch.randn(3, 5) > 0.0) * torch.randn(3, 5)).to_sparse()   # make a sparse tensor
>>> s
tensor(indices=tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2],
                       [0, 1, 2, 3, 2, 3, 2, 3, 4]]),
       values=tensor([-1.0556, -0.5344,  1.2262,  1.1603, -0.1418, -0.3402,
                       0.6979, -0.5543, -0.2561]),
       size=(3, 5), nnz=9, layout=torch.sparse_coo)
>>> lin = torch.nn.Linear(5, 3)   # make a (non-sparse) Linear
>>> t = lin(s)
>>> t
tensor([[-0.4231, -0.6271, -0.6817],
        [-0.2014, -0.0587,  0.1683],
        [-0.3242,  0.1646,  0.3315]], grad_fn=<AddmmBackward0>)
>>> t.sum().backward()   # backward works
>>> lin.weight.grad
tensor([[-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561]])
>>> w = torch.nn.Parameter(lin.weight.detach().clone())
>>> b = torch.nn.Parameter(lin.bias.detach().clone())
>>> t = torch.nn.functional.linear(s, w, b)   # functional form of linear also works
>>> t.sum().backward()
>>> w.grad
tensor([[-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561],
        [-1.0556, -0.5344,  1.7823,  0.2658, -0.2561]])
Best.
K. Frank