Hi,
I'm trying to implement a new operator using torch.autograd.Function. Currently I have something organized like below, using free functions to implement the long computations. This is not the true computation of course, just to give the logic (fun1 detaches and clones the data, fun2 just does some computation).
import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(X):
        # the actual work is delegated to a free function
        out = fun1(X)
        return out

    @staticmethod
    def setup_context(ctx, inputs, output):
        ...

    @staticmethod
    def backward(ctx, grad):
        return fun2(grad)

# stand-in for a long computation: detach and clone the data
def fun1(X):
    Y = X.detach().clone()
    return Y

# stand-in for the gradient computation
def fun2(X):
    X = X ** 2
    return X
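For context, I call the op like this (a minimal sketch; the tensor shape is just for illustration):

x = torch.randn(3, requires_grad=True)
y = MyOp.apply(x)  # custom autograd Functions are invoked through .apply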
I'm wondering whether the computation in fun1 and fun2 will be included in the computational graph, or if I need to set up a no_grad context or something equivalent?
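To make the question concrete, this is the kind of guard I'm asking about (a hypothetical variant of fun1; I don't know whether it's actually needed):

def fun1(X):
    # hypothetical: should the computation be wrapped like this?
    with torch.no_grad():
        Y = X.detach().clone()
    return Y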
In the (excellent) documentation, I have only found examples where all the computation is done inside the forward or backward functions themselves.