I have four parameters 'a', 'b', 'weight', and 'bias' defined in a PyTorch module.
In the forward pass:
Step 1: create a sparse matrix T that depends on the parameters 'a' and 'b'
Step 2: mat_1 = torch.mm(input, weight)
Step 3: output = SparseMM()(T, mat_1)
Step 4: return output + bias
where SparseMM is an autograd function implemented as below:
class SparseMM(torch.autograd.Function):
    """
    Sparse x dense matrix multiplication with autograd support.
    Implementation by Soumith Chintala:
    https://discuss.pytorch.org/t/
    does-pytorch-support-autograd-on-sparse-matrix/6156/7
    """
    def forward(self, matrix1, matrix2):
        self.save_for_backward(matrix1, matrix2)
        return torch.mm(matrix1, matrix2)

    def backward(self, grad_output):
        matrix1, matrix2 = self.saved_tensors
        grad_matrix1 = grad_matrix2 = None
        if self.needs_input_grad[1]:
            grad_matrix2 = torch.mm(matrix1.t(), grad_output)
        return grad_matrix1, grad_matrix2
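To make the setup concrete, here is a minimal self-contained sketch of the module described above. It uses the modern static-method `torch.autograd.Function` API (the legacy instance-style API in the snippet is deprecated); the module name `SparseLayer`, the parameter shapes, and the particular way T is assembled from 'a' and 'b' (a diagonal built from a * b) are all assumptions for illustration, not the original code:

```python
import torch
import torch.nn as nn

class SparseMM(torch.autograd.Function):
    # Sparse x dense matmul, rewritten with the static-method autograd API.
    @staticmethod
    def forward(ctx, matrix1, matrix2):
        ctx.save_for_backward(matrix1, matrix2)
        return torch.mm(matrix1, matrix2)

    @staticmethod
    def backward(ctx, grad_output):
        matrix1, matrix2 = ctx.saved_tensors
        grad_matrix1 = grad_matrix2 = None
        # Only the gradient w.r.t. the dense factor is computed, as in the
        # original snippet -- nothing flows back into the sparse matrix T.
        if ctx.needs_input_grad[1]:
            grad_matrix2 = torch.mm(matrix1.t(), grad_output)
        return grad_matrix1, grad_matrix2

class SparseLayer(nn.Module):  # hypothetical module name
    def __init__(self, in_features, out_features, n):
        super().__init__()
        self.a = nn.Parameter(torch.randn(n))
        self.b = nn.Parameter(torch.randn(n))
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, input):
        # Step 1: build a sparse n x n matrix T from a and b
        # (here a diagonal with values a * b -- an assumption)
        idx = torch.arange(self.a.numel())
        T = torch.sparse_coo_tensor(
            torch.stack([idx, idx]), self.a * self.b,
            (self.a.numel(), self.a.numel()))
        mat_1 = torch.mm(input, self.weight)  # Step 2
        output = SparseMM.apply(T, mat_1)     # Step 3
        return output + self.bias             # Step 4
```

Running a backward pass through this sketch reproduces the symptom: `weight.grad` and `bias.grad` are populated, while `a.grad` and `b.grad` stay `None`, because `backward` returns `None` for the sparse input.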
The weight and bias are the parameters that get updated, but the parameters 'a' and 'b' that create the sparse matrix don't get updated.
Is there a way to update these parameters as well?