Custom backward pass

Hi everyone,

I have a neural network that contains some complex operations (e.g., f3 and f6 in this figure). In the forward pass we use all the black arrows (i.e., all of f1, f2, …, f7); however, in the backward pass, instead of using f3 and f6, we want to use g3 and g6, which are pretty close to f3 and f6 but simpler.

Both f3 and f6 contain multiple layers and operations, so I define them in another class:

class CustomLayer_F(nn.Module):
    def __init__(self):
        super(CustomLayer_F, self).__init__()
        # defining all variables and operations
        
    def forward(self, input):
        # computing the forward pass

And for g3 and g6, I define:

class CustomLayer_G(nn.Module):
    def __init__(self):
        super(CustomLayer_G, self).__init__()
        # defining all variables and operations, different from F
        
    def forward(self, input):
        # computing the forward pass, different from F

In order to use CustomLayer_F in the forward pass and CustomLayer_G's gradients in the backward pass, I want something like:

class CustomForwardBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        # do whatever we do in CustomLayer_F

    @staticmethod
    def backward(ctx, grad_output):
        # forward through CustomLayer_G and use the gradients of G instead of F
        # However, I don't know how to use grad_output here!!!

Does anyone have any clue how I can make CustomForwardBackward forward like CustomLayer_F but backward through CustomLayer_G and use its gradients?

Thank you very much!

Cross-posting from Twitter:
Check out this post and see if it's applicable for this use case.


Thanks for the link and the discussion on Twitter! It was actually helpful; however, a simpler solution that worked for me was this:

class CustomForwardBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        with torch.enable_grad():
            # build the G graph so we can backprop through it later
            output = ctx.g3.forward(input)
            ctx.save_for_backward(input, output)
        # the value passed on to the rest of the network comes from F
        return ctx.f3.forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        # backprop grad_output through the G graph and return the input's gradient
        output.backward(grad_output, retain_graph=True)
        return input.grad
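
For anyone finding this later, here is a minimal, self-contained sketch of how this idea can be wired up end to end. The toy CustomLayer_F / CustomLayer_G bodies, the class-attribute way of exposing f3 / g3 (referenced as ctx.f3 / ctx.g3 in the snippet above), and the detached surrogate input are my own assumptions, so adapt them to your setup:

import torch
import torch.nn as nn

# Toy stand-ins (my assumption) so the sketch is runnable:
# F is the "hard" op used in the forward pass, G is the smooth
# surrogate whose gradient we want in the backward pass.
class CustomLayer_F(nn.Module):
    def forward(self, input):
        return torch.sign(input)

class CustomLayer_G(nn.Module):
    def forward(self, input):
        return torch.tanh(input)

class CustomForwardBackward(torch.autograd.Function):
    # module instances used by the Function (assumed setup)
    f3 = CustomLayer_F()
    g3 = CustomLayer_G()

    @staticmethod
    def forward(ctx, input):
        with torch.enable_grad():
            # build the G graph on a fresh leaf so its gradient can be read later
            surrogate_input = input.detach().requires_grad_(True)
            output = CustomForwardBackward.g3(surrogate_input)
            ctx.save_for_backward(surrogate_input, output)
        # the value returned to the rest of the network comes from F
        return CustomForwardBackward.f3(input)

    @staticmethod
    def backward(ctx, grad_output):
        surrogate_input, output = ctx.saved_tensors
        # backprop grad_output through the G graph instead of F
        output.backward(grad_output, retain_graph=True)
        return surrogate_input.grad

x = torch.randn(4, requires_grad=True)
y = CustomForwardBackward.apply(x)  # values come from F: sign(x)
y.sum().backward()                  # gradients come from G: d tanh(x)/dx
print(x.grad)

The surrogate leaf is only there so that its .grad is populated when the G graph is backpropagated; note that the outputs of F and G need matching shapes so that grad_output is valid for both.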

Hi Sadegh,

I was wondering what “ctx” actually does here? I tried to implement something exactly like what you did, but instead of calling the f3 and g3 functions through ctx, I called them as objects of their own classes, exactly as you defined them in your first post, like this (ClassF contains the desired functions for forward propagation and ClassB contains the desired functions for backward propagation):

class ClassF(nn.Module):
    def __init__(self):
        super(ClassF, self).__init__()

    def forward(self, input1, input2):
        return something

class ClassB(nn.Module):
    def __init__(self):
        super(ClassB, self).__init__()

    def forward(self, input1, input2):
        return something  # different than ClassF.forward()

class ForwardBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        with torch.enable_grad():
            output = ClassB.forward(input1, input2)
            ctx.save_for_backward(input1, output)
        return ClassF.forward(input1, input2)

    @staticmethod
    def backward(ctx, grad_output):
        input1, output = ctx.saved_tensors
        output.backward(grad_output, retain_graph=True)
        return input1.grad, input2.grad

When I run the whole code with this, I get an error saying that ClassB.forward() is missing 1 required positional argument: “input2”. It seems that output = ClassB.forward(input1, input2) also counts “self” as the first input! Could you please help me with this issue? Thanks!
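
For what it's worth, here is a tiny sketch of what I think is going on; the toy forward bodies are just placeholders of mine so the snippet runs:

import torch
import torch.nn as nn

# Placeholder bodies (mine), only so this snippet is runnable.
class ClassF(nn.Module):
    def forward(self, input1, input2):
        return input1 + input2

class ClassB(nn.Module):
    def forward(self, input1, input2):
        return input1 * input2

input1 = torch.randn(3)
input2 = torch.randn(3)

# Unbound call: input1 fills the `self` slot, so Python complains that
# "input2" is missing (the same error message I'm seeing).
# ClassB.forward(input1, input2)   # TypeError

# Bound calls: instantiate the modules first and call the instances,
# so `self` is supplied automatically.
model_f = ClassF()
model_b = ClassB()
out_f = model_f(input1, input2)
out_b = model_b(input1, input2)

Is creating the instances first (inside or outside the autograd.Function) the right way to go, or is there a cleaner pattern?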