As a conceptual (non-runnable) example, I am providing the following function.
I am trying to update the network's parameters with PyTorch's default autograd, but only with respect to this function's inputs and outputs.
class Square(torch.autograd.Function):
    """Compute ``net(x) ** 2`` with a custom backward that still trains ``net``.

    Call as ``Square.apply(x, net, *net.parameters())``. The parameters are
    passed as explicit inputs for a reason: a module argument is opaque to
    ``autograd.Function``. Without them, ``backward`` only runs when ``x``
    requires grad, and the parameters never receive a gradient.
    ``Square.apply(x, net)`` is still accepted; in that case only ``x``
    gets a gradient.

    ``backward`` re-runs ``net`` with grad mode enabled and differentiates
    that recomputed graph with ``torch.autograd.grad``. The gradients are
    therefore the same ones plain autograd would give for ``net(x) ** 2``.
    """

    @staticmethod
    def forward(ctx, x, net: nn.Module, *params):
        """Return ``net(x) ** 2``.

        Args:
            ctx: autograd context used to stash state for ``backward``.
            x: input tensor fed to ``net``.
            net: module applied to ``x``. It is not a tensor, so it is kept
                directly on ``ctx`` instead of going through
                ``save_for_backward``.
            *params: tensors used by ``net`` (normally ``net.parameters()``)
                that should receive gradients.
        """
        ctx.save_for_backward(x)
        ctx.net = net
        # Keep the exact tensor objects: gradients are taken w.r.t. them in the
        # graph recomputed in backward, so they must be the ones `net` uses.
        ctx.params = params
        # Function.forward already runs with grad disabled; no graph is built here.
        return net(x) ** 2

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(x, net, *params)``; ``net`` always gets ``None``."""
        (x,) = ctx.saved_tensors
        need_x = ctx.needs_input_grad[0]
        need_params = ctx.needs_input_grad[2:]
        no_grads = (None, None, *(None for _ in need_params))

        # backward runs with grad mode off by default; re-enable it so that
        # re-running `net` records a graph we can differentiate.
        with torch.enable_grad():
            x_leaf = x.detach().requires_grad_(need_x)
            out = ctx.net(x_leaf) ** 2

        # Only differentiate w.r.t. inputs autograd actually asked for;
        # torch.autograd.grad rejects tensors that do not require grad.
        targets = [x_leaf] if need_x else []
        targets += [p for p, need in zip(ctx.params, need_params) if need]
        if not targets or not out.requires_grad:
            return no_grads

        grads = iter(torch.autograd.grad(out, targets, grad_out, allow_unused=True))
        grad_x = next(grads) if need_x else None
        grad_params = [next(grads) if need else None for need in need_params]
        # One entry per forward input: x, net (non-tensor -> None), *params.
        return (grad_x, None, *grad_params)