I need to do a multiplication with a fixed sparse matrix. Since the matrix is fixed, I don’t need the gradient with respect to that matrix, only with respect to the other (dense) one. Because there’s no autograd for sparse matrices yet, I implemented it like this:
```python
import torch


class LeftMatMulSparseFixedWeights(torch.autograd.Function):
    """
    Implementation of matrix multiplication of a Sparse Variable with
    a Dense Variable, returning a Dense one.

    This is added because there's no autograd for sparse yet.
    No gradient computed on the sparse weights.
    """

    def forward(self, sparse_weights, x):
        self.save_for_backward(sparse_weights)
        return torch.mm(sparse_weights, x)

    def backward(self, grad_output):
        sparse_weights, = self.saved_tensors
        return None, torch.mm(sparse_weights.t(), grad_output)
```
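For comparison, here is a minimal sketch of the same function in the static-method style used by recent PyTorch releases (an assumption: this is newer than the version the post was written against). The forward computes `y = W x` with `W` fixed, so the only gradient needed is `dL/dx = W^T dL/dy`, which is what `backward` returns. In this style the saved state lives on a per-call `ctx` object rather than on a Function instance, and the function is invoked through `.apply` instead of being instantiated.

```python
import torch


class LeftMatMulSparseFixedWeights(torch.autograd.Function):
    """Dense result of sparse_weights @ x; no gradient for sparse_weights."""

    @staticmethod
    def forward(ctx, sparse_weights, x):
        # Saved on the per-call ctx, so nothing is shared between calls.
        ctx.save_for_backward(sparse_weights)
        return torch.mm(sparse_weights, x)

    @staticmethod
    def backward(ctx, grad_output):
        sparse_weights, = ctx.saved_tensors
        # y = W x  =>  dL/dx = W^T dL/dy; None means "no gradient for W".
        return None, torch.mm(sparse_weights.t(), grad_output)


# Usage: call .apply rather than creating an instance of the class.
W = torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 2]]),
                            torch.tensor([2.0, 3.0, 4.0]), (3, 3))
x = torch.randn(3, 2, requires_grad=True)
y = LeftMatMulSparseFixedWeights.apply(W, x)
y.backward(torch.ones_like(y))
```

Each `.apply` call creates a fresh graph node, so the same function can be used for every mini-batch.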
This “works”, but I get into trouble anyway:
If I create one function object in the enclosing Module, like so:
```python
import torch
import torch.nn as nn


class FixedSparseLinMod(nn.Module):
    """
    A module that reads a sparse matrix from a file and does the left
    matrix multiplication. Typical usage is a terms-class matrix for
    zero-shot learning.
    """

    def __init__(self, sparse_mat_file):
        super(FixedSparseLinMod, self).__init__()
        dims, inds, vals = read_sparse_tensor(sparse_mat_file)
        i = torch.LongTensor([[x for x in inds], [x for x in inds]])
        v = torch.FloatTensor(vals)
        s = torch.Size([len(dims), len(dims)])
        self.sparse_mat = nn.Parameter(torch.sparse.FloatTensor(i, v, s),
                                       requires_grad=False)
        self.matmul = LeftMatMulSparseFixedWeights()
```
and then use the following in the forward pass:
```python
def forward(self, x):
    return self.matmul(self.sparse_mat, x.t()).t()
```
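A possible alternative, sketched under the assumption of a recent PyTorch release in which `torch.sparse.mm` supports autograd: keep the fixed matrix as a buffer (via `register_buffer`) instead of a `Parameter` with `requires_grad=False`, and let the built-in sparse matmul handle the backward pass instead of a custom Function. The constructor arguments `rows`, `cols`, `vals` and `shape` are hypothetical stand-ins for whatever `read_sparse_tensor` returns, since its output format isn't shown here.

```python
import torch
import torch.nn as nn


class FixedSparseLinMod(nn.Module):
    """Left-multiplies the input by a fixed sparse matrix; no gradient for it.

    Hypothetical constructor: rows, cols, vals and shape stand in for the
    output of read_sparse_tensor, whose format is not shown in the post.
    """

    def __init__(self, rows, cols, vals, shape):
        super().__init__()
        indices = torch.tensor([rows, cols], dtype=torch.long)  # 2 x nnz
        values = torch.tensor(vals, dtype=torch.float32)
        sparse = torch.sparse_coo_tensor(indices, values, shape).coalesce()
        # A buffer follows .cuda()/.to() but is not returned by parameters().
        self.register_buffer("sparse_mat", sparse)

    def forward(self, x):
        # x: (batch, n_in) -> (S @ x^T)^T: (batch, n_out)
        return torch.sparse.mm(self.sparse_mat, x.t()).t()
```

Because the matrix is a buffer, `model.cuda()` or `model.to(device)` moves it along with the module, and it never appears in `model.parameters()`, so an optimizer won't try to update it.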
Then I get this:
```
in backward
    sparse_weights, = self.saved_tensors
RuntimeError: Trying to backward through the graph second time, but the buffers have already been freed. Please specify retain_variables=True when calling backward for the first time.
```
I also confirmed that forward and backward work fine for the first batch, but the backward pass throws for the second batch. OK, so it looks like the shared self.matmul is cleaning up some state in a weird way after each pass. So my first question is: “What’s getting freed here? Why isn’t the sparse_weights state saved correctly during the second forward pass?”
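For reference, the same error can be triggered deliberately in a recent PyTorch release with a plain dense graph: ops like `exp` save tensors for the backward pass, and those saved buffers are released once `backward()` has gone through them. This is a minimal sketch of that general mechanism, not a diagnosis of the shared `self.matmul` case. (The `retain_variables` flag named in the error message is called `retain_graph` in newer releases.)

```python
import torch

x = torch.randn(3, requires_grad=True)

# exp() saves its output for the backward pass.
y = x.exp().sum()
y.backward()  # the saved buffers are freed after this pass
try:
    y.backward()  # second pass through the same graph
    second_pass_failed = False
except RuntimeError:
    second_pass_failed = True

# With retain_graph=True the buffers survive; gradients accumulate in x.grad.
x.grad = None
z = x.exp().sum()
z.backward(retain_graph=True)
z.backward()
```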
I tried working around it by simply creating a new LeftMatMulSparseFixedWeights function object for each mini-batch, with this in the enclosing Module:
```python
def forward(self, x):
    return LeftMatMulSparseFixedWeights()(self.sparse_mat, x.t()).t()
```
This works, but it leaks memory: after some time the process crashes because it runs out of CUDA memory.
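One way to check whether Python-side references are piling up across mini-batches is to count live tensor objects with the `gc` module after each step. The helpers below (`count_live_tensors` and `track_tensor_counts`, both hypothetical names) are a rough diagnostic sketch: a count that keeps growing points at something holding on to tensors, while tensors referenced only from inside the C++ autograd graph may not be visible this way.

```python
import gc

import torch


def count_live_tensors():
    """Count tensor objects currently tracked by Python's garbage collector."""
    count = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                count += 1
        except Exception:  # be defensive about unusual objects
            pass
    return count


def track_tensor_counts(step, n_iters):
    """Run step() n_iters times, recording the live-tensor count after each."""
    counts = []
    for _ in range(n_iters):
        step()
        gc.collect()  # drop unreachable cycles before counting
        counts.append(count_live_tensors())
    return counts
```

In a training loop, `step` would run one forward/backward pass on a mini-batch. On the GPU, `torch.cuda.memory_allocated()` gives a more direct reading of allocated CUDA memory.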
I think the first approach is the correct one, but I can’t find a way to avoid that error. As for the second approach, I don’t really understand the memory leak either: isn’t the function object freed automatically?
Help very much appreciated!