Implement models that include gradients

Thank you for these hints! I tried to fill in the blanks and came up with this:

import torch

# F is the step function from before: (x, h) -> y, built from learnable parameters.
class RNNstep(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x_in, h_in):
    with torch.enable_grad():
      # Work on detached copies so the callers' tensors are never mutated
      # (setting .requires_grad on a non-leaf input would raise an error).
      x_in_copy = x_in.detach().clone().requires_grad_()
      h_in_copy = h_in.detach().clone().requires_grad_()
      y_out = F(x_in_copy, h_in_copy)
      # The next hidden state is the gradient of the summed output w.r.t.
      # the old hidden state; this assumes y_out.sum(dim=0) is a scalar.
      # create_graph=True keeps the inner graph alive so backward() can
      # differentiate through this gradient again.
      h_out, = torch.autograd.grad(
          y_out.sum(dim=0),
          h_in_copy,
          create_graph=True)

    ctx.save_for_backward(x_in_copy, h_in_copy, y_out, h_out)

    # Detach the outputs so the outer graph only records this Function's node.
    return y_out.detach(), h_out.detach()

  @staticmethod
  def backward(ctx, grad_y_out, grad_h_out):
    x_in_copy, h_in_copy, y_out, h_out = ctx.saved_tensors

    # Chain rule: each input receives one contribution through y_out and one
    # through h_out. A single grad() call over both outputs sums the two
    # contributions and needs no retain_graph.
    grad_x_in, grad_h_in = torch.autograd.grad(
        (y_out, h_out),
        (x_in_copy, h_in_copy),
        (grad_y_out, grad_h_out))

    return grad_x_in, grad_h_in
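
For completeness, this is roughly how I am calling it. The toy F, the layer sizes, and the sequence shapes below are just placeholders I made up for testing; F returns one scalar per batch element, so that y_out.sum(dim=0) really is a scalar:

# Toy test setup -- the layers and shapes are made up.
W_x = torch.nn.Linear(3, 5)
W_h = torch.nn.Linear(5, 5)

def F(x, h):
  # One scalar per batch element, so y_out.sum(dim=0) is a true scalar.
  return torch.tanh(W_x(x) + W_h(h)).sum(dim=1)

x_seq = torch.randn(7, 4, 3)                # (time, batch, features)
h0 = torch.zeros(4, 5, requires_grad=True)  # initial hidden state
h = h0

ys = []
for x_t in x_seq:
  y_t, h = RNNstep.apply(x_t, h)
  ys.append(y_t)

torch.stack(ys).sum().backward()
print(h0.grad.shape)    # gradients do reach the initial state: torch.Size([4, 5])
print(W_x.weight.grad)  # ...but stay None for the parameters of F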

Although this code runs, I am quite sure that it is not correct.
For example: do the gradients w.r.t. the parameters of F get accumulated at all? Since those parameters are not inputs of the Function, I don't see how their gradients could ever reach the outer graph.
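
If I understand the mechanics correctly, gradients can only leave an autograd.Function through tensors that are explicit inputs of forward(), so the parameters of F would have to be passed in as extra arguments. This is my untested guess at how that could look; it assumes F really uses those same parameter tensors internally:

class RNNstepWithParams(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x_in, h_in, *params):
    with torch.enable_grad():
      x_in_copy = x_in.detach().clone().requires_grad_()
      h_in_copy = h_in.detach().clone().requires_grad_()
      y_out = F(x_in_copy, h_in_copy)  # must reference the tensors in `params`
      h_out, = torch.autograd.grad(
          y_out.sum(dim=0), h_in_copy, create_graph=True)
    ctx.save_for_backward(x_in_copy, h_in_copy, y_out, h_out, *params)
    return y_out.detach(), h_out.detach()

  @staticmethod
  def backward(ctx, grad_y_out, grad_h_out):
    x_in_copy, h_in_copy, y_out, h_out = ctx.saved_tensors[:4]
    params = ctx.saved_tensors[4:]
    # One gradient per forward() argument, in the same order; allow_unused
    # covers parameters that do not influence this particular step.
    return torch.autograd.grad(
        (y_out, h_out),
        (x_in_copy, h_in_copy) + params,
        (grad_y_out, grad_h_out),
        allow_unused=True)

Calling it as RNNstepWithParams.apply(x_t, h, *W_x.parameters(), *W_h.parameters()) should then, if I am not mistaken, make the parameter gradients accumulate into .grad as usual. But I would appreciate confirmation that this is the intended approach.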

You spoke of snippets which I could find somewhere in this forum. Can you be a bit more specific about what I should search for? The closest thing I found was this.

I am not quite sure what you mean by “plug output & gradients into the outer graph”.

I spent a lot of time today learning about torch.autograd, but apparently I have not really figured it out yet…