Thank you for these hints! I tried to fill in the blanks and came up with this:
import torch

class RNNstep(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x_in, h_in):
        # F is the step function of my RNN (defined elsewhere)
        with torch.enable_grad():
            x_in.requires_grad = True
            # start the inner graph from a detached copy of the hidden state
            h_in_copy = h_in.detach().clone().requires_grad_()
            y_out = F(x_in, h_in_copy)
            # new hidden state = gradient of the summed output w.r.t. the old one
            h_out, = torch.autograd.grad(
                y_out.sum(dim=0),
                h_in_copy,
                create_graph=True)
        ctx.save_for_backward(x_in, h_in_copy, y_out, h_out)
        return y_out.detach(), h_out.detach()

    @staticmethod
    def backward(ctx, grad_y_out, grad_h_out):
        x_in, h_in_copy, y_out, h_out = ctx.saved_tensors
        # chain the incoming gradients through both outputs of forward
        grad_x_in1, = torch.autograd.grad(y_out, x_in, grad_y_out, retain_graph=True)
        grad_x_in2, = torch.autograd.grad(h_out, x_in, grad_h_out, retain_graph=True)
        grad_h_in1, = torch.autograd.grad(y_out, h_in_copy, grad_y_out, retain_graph=True)
        grad_h_in2, = torch.autograd.grad(h_out, h_in_copy, grad_h_out, retain_graph=True)
        return grad_x_in1 + grad_x_in2, grad_h_in1 + grad_h_in2
Although this code runs, I am quite sure that it is not correct.
For example: Do the gradients w.r.t. the parameters of F get accumulated at all?
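To see whether the parameter gradients arrive at all, the best test I could come up with is the toy check below. Everything in it is made up for the experiment: cell stands in for the parameters of F, F itself is reduced to a one-layer tanh step, and the tensors are 1-D so that y_out.sum(dim=0) stays a scalar.

    import torch

    torch.manual_seed(0)
    cell = torch.nn.Linear(3, 3)            # stands in for the parameters of F

    def F(x_in, h_in):                      # toy one-layer step, just for this check
        return torch.tanh(cell(x_in) + cell(h_in))

    x = torch.randn(3, requires_grad=True)
    h = torch.randn(3, requires_grad=True)

    # path 1: through the custom Function
    y, h_next = RNNstep.apply(x, h)
    (y.sum() + h_next.sum()).backward()
    print("weight grad via RNNstep:", cell.weight.grad)

    # path 2: plain autograd straight through F, as a reference
    cell.zero_grad()
    y_ref = F(x.detach().clone().requires_grad_(), h.detach().clone().requires_grad_())
    y_ref.sum().backward()
    print("weight grad via plain autograd:", cell.weight.grad)

If the first print shows None while the second shows a tensor, that would confirm my suspicion that torch.autograd.grad inside backward only returns gradients for the inputs I list and does not accumulate anything into the .grad fields of F's parameters.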
You spoke of snippets which I could find somewhere in this forum. Can you be a bit more specific about what I should search for? The closest thing I found was this.
I am not quite sure what you mean by “plug output & gradients into the outer graph”.
I tried to learn a lot about torch.autograd today, but apparently I have not really figured it out yet…