Custom backward pass

Hi everyone,

I have a neural network that contains some complex operations (e.g., f3 and f6 in this figure). In the forward pass we use all the black arrows (i.e., all of f1, f2, …, f7). In the backward pass, however, instead of f3 and f6 we want to use g3 and g6, which are close approximations of f3 and f6 but simpler.
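Written out for a single block with input x, output y, loss L and parameters θ_g of g (generic symbols, not from the original post), the goal is to keep the forward value of f but replace f's Jacobian with g's in the chain rule:

```latex
\begin{aligned}
\text{forward:}\quad & y = f(x) \\
\text{standard backward:}\quad & \frac{\partial L}{\partial x} = \left(\frac{\partial f}{\partial x}\right)^{\!\top} \frac{\partial L}{\partial y} \\
\text{desired backward:}\quad & \frac{\partial L}{\partial x} \approx \left(\frac{\partial g}{\partial x}\right)^{\!\top} \frac{\partial L}{\partial y},
\qquad \frac{\partial L}{\partial \theta_g} \approx \left(\frac{\partial g}{\partial \theta_g}\right)^{\!\top} \frac{\partial L}{\partial y}
\end{aligned}
```

Here ∂L/∂y is exactly what PyTorch passes to `torch.autograd.Function.backward` as `grad_output`.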

Both f3 and f6 contain multiple layers and operations, so I define them in another class:

import torch
import torch.nn as nn

class CustomLayer_F(nn.Module):
    def __init__(self):
        super().__init__()
        # define all layers and operations of F here

    def forward(self, input):
        # compute F's forward pass here
        raise NotImplementedError

And for g3 and g6, I define:

class CustomLayer_G(nn.Module):
    def __init__(self):
        super().__init__()
        # define all layers and operations of G here, different from F

    def forward(self, input):
        # compute G's forward pass here, different from F
        raise NotImplementedError

In order to use CustomLayer_F in the forward pass and CustomLayer_G's gradients in the backward pass, I want something like:

class CustomForwardBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        # do whatever we do in CustomLayer_F
        ...

    @staticmethod
    def backward(ctx, grad_output):
        # forward through CustomLayer_G and use the gradients of G instead of F
        # However, I don't know how to use grad_output here!
        ...
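For context on `grad_output`: it holds ∂L/∂output, and `backward` must return ∂L/∂input, i.e. `grad_output` multiplied by whichever Jacobian you want to use. When g has a closed form, that product can be written by hand. A minimal sketch, with the hypothetical choices f = `torch.sign` and g = `torch.tanh` (not from the original post):

```python
import torch


class SignWithTanhGrad(torch.autograd.Function):
    """Forward: f(x) = sign(x). Backward: use the gradient of g(x) = tanh(x)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # grad_output is dL/dy; multiplying by g's derivative,
        # d tanh(x)/dx = 1 - tanh(x)^2, gives dL/dx (element-wise,
        # because g acts element-wise)
        return grad_output * (1 - torch.tanh(input) ** 2)


x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
w = torch.tensor([1.0, 2.0, 3.0])
y = SignWithTanhGrad.apply(x)
(w * y).sum().backward()  # for this loss, grad_output == w
```

When g is a multi-layer `nn.Module` with its own parameters, writing this product by hand is impractical, so it is better left to autograd by building G's graph inside the Function.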

Does anyone know how I can use CustomForwardBackward to run the forward pass like CustomLayer_F but the backward pass through CustomLayer_G, using its gradients?

Thank you very much!

Cross-posting from Twitter:
Check out this post and see if it’s applicable to this use case.


Thanks for the link and the discussion on Twitter! It was helpful, but a simpler solution that worked for me was this:

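import torch

class CustomForwardBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, f, g):
        # f, g: the modules (e.g. f3, g3), passed in via
        # CustomForwardBackward.apply(input, f3, g3); ctx cannot reach them otherwise
        with torch.enable_grad():
            # Start G's graph from a detached leaf so its .grad can be read in backward
            x = input.detach().requires_grad_()
            g_out = g(x)
        ctx.save_for_backward(x, g_out)
        # forward() runs under no_grad, so F builds no graph and gets no gradients
        return f(input)

    @staticmethod
    def backward(ctx, grad_output):
        x, g_out = ctx.saved_tensors
        # Backprop grad_output through G; this also accumulates into G's parameters
        g_out.backward(grad_output)
        # One return value per forward() argument; f and g need none
        return x.grad, None, None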
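An alternative that avoids writing an `autograd.Function` at all is the detach trick commonly used for straight-through estimators: `g(x) + (f(x) - g(x)).detach()` has the value of `f(x)` (up to floating-point rounding), but its gradients, with respect to both `x` and g's parameters, are those of `g`. A minimal sketch; `f3`, `g3` and `forward_f_backward_g` are hypothetical stand-ins, not code from this thread:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: f3 is the "expensive" block, g3 its simpler approximation
f3 = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 4))
g3 = nn.Linear(4, 4)


def forward_f_backward_g(x, f, g):
    """Return a tensor whose value is f(x) but whose gradients come from g(x)."""
    gx = g(x)
    # (f(x) - g(x)) is detached: it shifts the value to f(x)
    # without contributing any gradient (so f's parameters get none).
    return gx + (f(x) - gx).detach()


x = torch.randn(2, 4, requires_grad=True)
y = forward_f_backward_g(x, f3, g3)
y.sum().backward()
```

Like the Function-based version, this runs both F and G in the forward pass and keeps G's graph alive until the backward pass; F's parameters receive no gradient.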