Legacy autograd function with non-static forward method - old PyTorch code

Trying to run some code from here. If I run test.py after removing the require_gradients() call on line 39 (which I’m not sure how to avoid the leaf variable error without), I still receive the following error message:


RuntimeError                              Traceback (most recent call last)
Input In [23], in
     61 print(f"Solving K={K} linear systems that are {n} x {n} with {As[0].nnz} nonzeros and {m} right hand sides.")
     63 cg = CG(A_bmm, M_bmm=M_bmm, rtol=1e-5, atol=1e-5, verbose=True)
---> 64 X = cg(B_torch)
     66 start = time.perf_counter()
     67 X_np = np.concatenate([np.hstack([splinalg.cg(A, B[:, i], M=M)[0][:, np.newaxis] for i in range(m)])[np.newaxis, :, :]
     68                        for A, B, M in zip(As, Bs, Ms)], 0)

File ~/.local/lib/python3.9/site-packages/torch/autograd/function.py:261, in Function.__call__(self, *args, **kwargs)
    260 def __call__(self, *args, **kwargs):
--> 261     raise RuntimeError(
    262         "Legacy autograd function with non-static forward method is deprecated. "
    263         "Please use new-style autograd function with static forward method. "
    264         "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)")

RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.


I’ve read some posts on this forum about this error, but I cannot single out the problem in the definitions of the forward and backward functions:

class CG(torch.autograd.Function):

    def __init__(self, A_bmm, M_bmm=None, rtol=1e-3, atol=0., maxiter=None, verbose=False):
        self.A_bmm = A_bmm
        self.M_bmm = M_bmm
        self.rtol = rtol
        self.atol = atol
        self.maxiter = maxiter
        self.verbose = verbose

    def forward(self, B, X0=None):
        X, _ = cg_batch(self.A_bmm, B, M_bmm=self.M_bmm, X0=X0, rtol=self.rtol,
                        atol=self.atol, maxiter=self.maxiter, verbose=self.verbose)
        return X

    def backward(self, dX):
        dB, _ = cg_batch(self.A_bmm, dX, M_bmm=self.M_bmm, rtol=self.rtol,
                         atol=self.atol, maxiter=self.maxiter, verbose=self.verbose)
        return dB

The error message links to an example of a custom autograd.Function, which includes, e.g., the @staticmethod decorators that are missing in your code.
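For reference, a minimal sketch of what the new-style version could look like, assuming cg_batch is the helper with the signature used in the snippet above. forward and backward become static methods that take a ctx object as their first argument, and the configuration that used to be set in __init__ is passed to forward and stashed on ctx instead:

import torch

class CG(torch.autograd.Function):

    @staticmethod
    def forward(ctx, B, A_bmm, M_bmm=None, X0=None, rtol=1e-3, atol=0.,
                maxiter=None, verbose=False):
        # Solve A X = B with the (assumed) cg_batch helper from the snippet above.
        X, _ = cg_batch(A_bmm, B, M_bmm=M_bmm, X0=X0, rtol=rtol,
                        atol=atol, maxiter=maxiter, verbose=verbose)
        # Stash the non-tensor configuration on ctx for use in backward.
        ctx.A_bmm = A_bmm
        ctx.M_bmm = M_bmm
        ctx.rtol = rtol
        ctx.atol = atol
        ctx.maxiter = maxiter
        ctx.verbose = verbose
        return X

    @staticmethod
    def backward(ctx, dX):
        # For symmetric A (which CG assumes), the gradient of the solve
        # is another solve with the same system.
        dB, _ = cg_batch(ctx.A_bmm, dX, M_bmm=ctx.M_bmm, rtol=ctx.rtol,
                         atol=ctx.atol, maxiter=ctx.maxiter, verbose=ctx.verbose)
        # backward must return one value per input of forward (excluding ctx);
        # None for the inputs that don't need gradients here.
        return dB, None, None, None, None, None, None, None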

@ptrblck thanks. When I add the @staticmethod decorators before the forward and backward definitions, I still get the same error. What other decorators are important?
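If it helps: the remaining piece is most likely not another decorator but the call site. Even with @staticmethod added, instantiating CG(...) and calling the instance still goes through the legacy Function.__call__ path that raises this error. New-style functions are invoked through the apply classmethod instead, which (at least as of 1.10) only takes positional arguments. With the sketch above, the call from line 64 of the traceback would become something like:

# Hypothetical call site for the sketched new-style CG above;
# apply takes positional arguments only, so the defaults are spelled out.
X = CG.apply(B_torch, A_bmm, M_bmm, None, 1e-5, 1e-5, None, True)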