My code needs a custom backward function implemented with `torch.autograd.Function`, but when I transform the model with `torch.fx.symbolic_trace`, my custom backward function is not traced (the traced module no longer uses it).
How can I get around this?
For example:
import torch
from torch.fx import symbolic_trace
class ActFun(torch.autograd.Function):
    """Step function in forward, identity (straight-through) gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        # gt() is not differentiable; the custom backward below supplies the gradient.
        return x.gt(0.0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient through unchanged.
        return grad_output


@torch.fx.wrap
def act_fun(x):
    """Apply ``ActFun`` to *x*.

    Registered with ``torch.fx.wrap`` so that ``symbolic_trace`` records each
    call as a single opaque ``call_function`` node instead of tracing into
    ``ActFun.forward``. Without this, only the ``gt``/``float`` ops would be
    recorded, the custom backward would be dropped, and the traced module's
    output would have no ``grad_fn``.
    """
    # Autograd Functions are used via the class's ``apply``; instantiating
    # them is deprecated.
    return ActFun.apply(x)


class TestModule(torch.nn.Module):
    """Applies ``ActFun``; stays differentiable after ``torch.fx`` tracing."""

    def forward(self, x):
        # Call the module-level name: fx.wrap patches that global during
        # tracing, so calling through an attribute would bypass the wrapper.
        return act_fun(x)
# Trace TestModule with FX, then run a forward/backward pass through the
# traced module and show the output and the input's gradient.
mfunc = symbolic_trace(TestModule())
a = torch.ones(1).requires_grad_()  # leaf tensor that will receive .grad
b = mfunc(a)
b.backward()  # b has a single element, so no explicit gradient argument is needed
print(b, a.grad)