Python builtin <built-in method apply of FunctionMeta object at 0x561db6e84610> is currently not supported in Torchscript:

Hi All,

I just wanted to ask a brief follow-up question about applying torch.jit.script to custom torch.autograd.Function subclasses. This has been asked before, and the answer has been that it is not currently supported.

Is there any update on this? I have seen this issue: https://github.com/pytorch/pytorch/issues/22329.

I also found this article as a potential solution. Is it the only way to script a custom autograd function in PyTorch as of 1.10?
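For reference, one partial workaround is to hide the `Function.apply` call behind `@torch.jit.ignore`, so the compiler leaves that method as a plain Python call and the rest of the module still scripts. This is a minimal sketch, not an official recommendation: the method name `square` is hypothetical, and the backward here is a simplified single-`Function` version of the `Square` below. Per the `torch.jit.ignore` documentation, ignored functions dispatch back to the Python interpreter and models containing them cannot be exported, so this only helps while the scripted module stays inside a Python process.

```python
import torch
import torch.nn as nn
from torch.autograd import Function


class Square(Function):
    # Simplified custom function: y = x ** 2 with a hand-written backward.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_y


class CustomFunc(nn.Module):
    @torch.jit.ignore
    def square(self, x: torch.Tensor) -> torch.Tensor:
        # Not compiled by TorchScript; called through the Python interpreter.
        return Square.apply(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.square(x)


# Scripting succeeds because the compiler never sees Square.apply.
scripted = torch.jit.script(CustomFunc())

x = torch.randn(10, requires_grad=True)
y = scripted(x)
y.sum().backward()  # gradients still flow through the custom backward
print(torch.allclose(x.grad, 2 * x.detach()))
```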

Thank you!

import torch
import torch.nn as nn
from torch.autograd import Function

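class Square(Function):
  # y = x ** 2; the backward is itself a custom Function so that
  # double backward (the gradient of the gradient) also works.
  @staticmethod
  def forward(ctx, x):
    ctx.save_for_backward(x)
    return x ** 2

  @staticmethod
  def backward(ctx, grad_y):
    x, = ctx.saved_tensors
    return SquareBackward.apply(x, grad_y)

class SquareBackward(Function):
  # Computes grad_x = grad_y * 2 * x, differentiable w.r.t. both inputs.
  @staticmethod
  def forward(ctx, x, grad_y):
    ctx.save_for_backward(x, grad_y)
    return grad_y * 2 * x

  @staticmethod
  def backward(ctx, grad_grad_x):
    x, grad_y = ctx.saved_tensors
    # Partial derivatives of grad_y * 2 * x w.r.t. x and grad_y, respectively.
    return 2 * grad_y * grad_grad_x, 2 * x * grad_grad_x
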

class CustomFunc(nn.Module):

  def __init__(self):
    super().__init__()

  def forward(self, x):
    return Square.apply(x)

func = CustomFunc()

x = torch.randn(10)
y = func(x)

print("x: ", x)
print("y: ", y)
print("x^2: ", x.pow(2))  # check that the eager-mode result matches

jit_func = torch.jit.script(func)  # fails here: TorchScript cannot compile Square.apply
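If a hand-written backward is not strictly needed, another option is to sidestep the custom `Function` entirely: express the operation with ordinary tensor ops that TorchScript can compile and let autograd derive the gradients. This is a sketch under that assumption (it does not script a custom `Function`, it replaces one); `SquareModule` is a hypothetical name. It checks both the first and second derivatives, which `SquareBackward` above provides by hand.

```python
import torch
import torch.nn as nn


class SquareModule(nn.Module):
    # Hypothetical replacement for CustomFunc: plain tensor ops only, so
    # TorchScript compiles it and autograd supplies backward and double backward.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x


scripted = torch.jit.script(SquareModule())

x = torch.randn(10, dtype=torch.double, requires_grad=True)
y = scripted(x)

# First derivative: d(sum(x * x))/dx = 2x; keep the graph for a second pass.
(grad_x,) = torch.autograd.grad(y.sum(), x, create_graph=True)
# Second derivative: d(sum(2x))/dx = 2 for every element.
(grad_grad_x,) = torch.autograd.grad(grad_x.sum(), x)

print(torch.allclose(grad_x, 2 * x))
print(torch.allclose(grad_grad_x, torch.full_like(x, 2.0)))
```

This only works when the gradient can be derived from TorchScript-supported ops; a genuinely custom backward still needs one of the other approaches discussed above.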