Hi All,
I just wanted to ask a brief follow-up question on torch.jit.script
being applied to torch.autograd.Function
subclasses. This has been asked before, and the answer has been that it's not currently supported.
Is there any update on this? I did see this issue here, https://github.com/pytorch/pytorch/issues/22329.
I did find this article here as a potential solution. Is this the only way to script a custom function within PyTorch as of 1.10?
Thank you!
import torch
import torch.nn as nn
from torch.autograd import Function
class Square(Function):
    """Custom autograd Function computing ``y = x ** 2``.

    The backward pass is itself implemented as a custom Function
    (``SquareBackward``) so that the gradient graph is differentiable,
    i.e. double backward works.
    """

    @staticmethod
    def forward(ctx, x):
        # Stash the input: backward needs it to compute dy/dx = 2 * x.
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        # Delegate to another Function instead of returning the product
        # directly, so the gradient itself supports backpropagation.
        return SquareBackward.apply(x, grad_y)
class SquareBackward(Function):
    """Backward pass of ``Square`` as its own differentiable Function.

    Forward computes ``grad_y * 2 * x`` (the VJP of ``x ** 2``); having it
    as a Function makes that expression differentiable a second time.
    """

    @staticmethod
    def forward(ctx, x, grad_y):
        # Both inputs are needed to differentiate f(x, grad_y) = 2 * x * grad_y.
        ctx.save_for_backward(x, grad_y)
        return grad_y * 2 * x

    @staticmethod
    def backward(ctx, grad_grad_x):
        x, grad_y = ctx.saved_tensors
        # f(x, grad_y) = 2 * x * grad_y
        #   df/dx      = 2 * grad_y
        #   df/dgrad_y = 2 * x
        return 2 * grad_y * grad_grad_x, 2 * x * grad_grad_x
class CustomFunc(nn.Module):
    """Thin ``nn.Module`` wrapper around the ``Square`` autograd Function."""

    def __init__(self):
        super(CustomFunc, self).__init__()

    def forward(self, x):
        # Custom Functions must be invoked through .apply, not called directly.
        return Square.apply(x)
# Sanity-check the module eagerly, then try to TorchScript it.
func = CustomFunc()
x = torch.randn(10)
y = func(x)

print("x: ", x)
print("y: ", y)
print("x^2: ", x.pow(2))  # check it all works...

# Scripting a module whose forward uses a custom autograd Function is
# where the reported failure occurs.
jit_func = torch.jit.script(func)  # jit function, fails here