Hello everyone, I’m trying to use an implemented layer that uses torch.autograd.Function.
I get the error mentioned in the headline:
“RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.”
I tried to update with the @staticmethod
The layer is implemented as follows:
I still get the error and I'm not sure why. I looked at the example the error message points to, and at some other threads here about the same problem, but couldn't fix it.
Hi @gil_bc,
when I ran your code I got a different error message, because you only return the gradient for one of your input tensors.
I got the following running, but it's just a normal matrix multiplication. I'm also not sure what type your sparse input is, i.e. whether it's a regular Tensor or a sparse tensor. I tested mine only with regular (dense) Tensors.
import torch

class SparseMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, sparse, dense):
        ctx.save_for_backward(sparse, dense)
        return torch.mm(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        sparse, dense = ctx.saved_tensors
        grad_sparse = grad_dense = None
        # Only compute a gradient for inputs that actually need one.
        if ctx.needs_input_grad[0]:
            grad_sparse = torch.mm(grad_output, dense.t())
        if ctx.needs_input_grad[1]:
            grad_dense = torch.mm(sparse.t(), grad_output)
        return grad_sparse, grad_dense

S = torch.randn(3, 2, requires_grad=True)
D = torch.randn(4, 2, requires_grad=True)

# New-style Functions are called through .apply, never instantiated.
sparse_mm = SparseMM.apply
x = sparse_mm(S, D.t())
x.backward(torch.ones_like(x))
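As a quick aside (not part of the original reply): a hand-written backward like the one above can be checked numerically with torch.autograd.gradcheck. A sketch, repeating the class so the snippet runs on its own; gradcheck expects double-precision inputs:

```python
import torch

# Same SparseMM as above, repeated so this snippet is self-contained.
class SparseMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, sparse, dense):
        ctx.save_for_backward(sparse, dense)
        return torch.mm(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        sparse, dense = ctx.saved_tensors
        grad_sparse = grad_dense = None
        if ctx.needs_input_grad[0]:
            grad_sparse = torch.mm(grad_output, dense.t())
        if ctx.needs_input_grad[1]:
            grad_dense = torch.mm(sparse.t(), grad_output)
        return grad_sparse, grad_dense

# gradcheck compares the analytic backward against finite differences,
# so it wants float64 inputs with requires_grad=True.
S = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
Dt = torch.randn(2, 4, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(SparseMM.apply, (S, Dt))
print(ok)  # True when the hand-written gradients match
```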
Hi @Caruso,
thank you for your answer!
I am new to PyTorch and trying to use this code, which is a bit above my current level of understanding. I really appreciate your help.
This code used to work until recently, with only a warning about the @staticmethod change. Now it gives an error and doesn't work anymore.
I tried the way you implemented it, but I still get the same autograd-style error.
I tried several other things I saw in other threads, but I still can't make it run.
Looking at the original code, do you know how to adjust it to the new style?
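One possible cause worth checking (an assumption, since the failing call site isn't shown in the thread): even after adding @staticmethod to forward and backward, calling the Function old-style by instantiating it still raises the legacy-autograd RuntimeError; a new-style Function has to be invoked through .apply. A minimal sketch with a made-up Double function:

```python
import torch

# Hypothetical toy Function (not from the thread) that doubles its input.
class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

x = torch.ones(3, requires_grad=True)
# y = Double()(x)    # old-style call: raises the legacy-autograd RuntimeError
y = Double.apply(x)  # new-style call goes through .apply
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```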
Well, I managed to make this run, though not by fixing the error I had.
I just didn't use the SparseMM function at all; instead I used this:
torch.sparse.mm(D, x.t()).t()
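For anyone landing here, that workaround relies on torch.sparse.mm, which handles sparse-dense products with autograd support built in. A minimal self-contained sketch with made-up tensors (A, B, and their shapes are illustrative, not the thread's D and x):

```python
import torch

# Made-up 2x3 sparse COO matrix; the indices and values are illustrative only.
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0])
A = torch.sparse_coo_tensor(i, v, (2, 3))
B = torch.randn(3, 4, requires_grad=True)  # dense operand

out = torch.sparse.mm(A, B)  # dense 2x4 result, differentiable w.r.t. B
out.sum().backward()
print(B.grad.shape)  # torch.Size([3, 4])
```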