Custom gradient with different forward and backward paths

Hi everyone,

I want to train a model in which one module uses different functions for the forward and backward passes. I got the idea from this link, but when I run the code, gradients do flow through the module, yet the loss never decreases (it does not converge). Here is a sketch of my code:

class Forward_function(torch.nn.Module):
    def __init__(self):
        super(Forward_function, self).__init__()

    def forward(self, input1, input2):
        # non-differentiable computation (placeholder)
        return output_forward

class Backward_function(torch.nn.Module):
    def __init__(self):
        super(Backward_function, self).__init__()

    def forward(self, input1, input2):
        # differentiable surrogate computation (placeholder)
        return output_backward

Here the forward function is not differentiable (it cannot pass gradients), and we intend to use the backward function for backpropagation. In other words, we want “output_forward” to be the output of the forward pass, while the gradients in the backward pass are taken from “output_backward”. The combined module then looks like:

forward_class = Forward_function()
backward_class = Backward_function()

class Forward_Backward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        with torch.enable_grad():
            output = backward_class.forward(input1, input2)
            ctx.save_for_backward(output, input1, input2)
        return forward_class.forward(input1, input2)

    @staticmethod
    def backward(ctx, grad_output):
        output, input1, input2 = ctx.saved_tensors
        input1_grad = torch.autograd.grad(output, input1, grad_output)
        input2_grad = torch.autograd.grad(output, input2, grad_output)
        return input1_grad[0], input2_grad[0]

For more detail: input1_grad and input2_grad are tuples, and the gradient is their first element. Also, we chose “torch.autograd.grad” over “torch.autograd.backward” for computing the gradients, because calling output.backward() in the middle of loss.backward() would recompute gradients for every leaf variable from the “Forward_Backward” module all the way back to the beginning of the graph. Nesting another .backward() call inside loss.backward() seemed like the wrong way to proceed.
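As a side note on the API: torch.autograd.grad always returns a tuple (one entry per listed input) and accepts several inputs in a single call, which avoids walking the graph twice. A quick illustration with toy tensors (nothing from my actual model):

```python
import torch

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
out = a * b

# One call returns the gradients for both inputs as a tuple.
# Two separate calls on the same graph would instead need retain_graph=True,
# because the first call frees the graph's intermediate buffers.
ga, gb = torch.autograd.grad(out, (a, b), grad_outputs=torch.ones(4))
print(torch.allclose(ga, b), torch.allclose(gb, a))  # True True
```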

Does anyone see what is wrong with my code? Any help getting past this problem would be much appreciated.
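To make the question concrete, here is a minimal runnable sketch of the pattern I am aiming for, with toy stand-ins (torch.round as the non-differentiable forward, a smooth sum as the surrogate; both are my own placeholders, not my real functions). In this toy version, detaching the inputs and making one autograd.grad call does produce the surrogate's gradients, but in my real setup the loss still does not decrease:

```python
import torch

# Toy stand-ins for the two paths (assumptions, not the real functions):
def forward_fn(x1, x2):
    return torch.round(x1 + x2)      # non-differentiable forward path

def backward_fn(x1, x2):
    return x1 + x2                   # differentiable surrogate

class ForwardBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        ctx.save_for_backward(input1, input2)
        return forward_fn(input1, input2)

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        with torch.enable_grad():
            # Rebuild the surrogate graph locally, detached from the outer graph.
            in1 = input1.detach().requires_grad_(True)
            in2 = input2.detach().requires_grad_(True)
            surrogate = backward_fn(in1, in2)
        # One grad call for both inputs; two separate calls would need
        # retain_graph=True, since the first call frees the surrogate's graph.
        return torch.autograd.grad(surrogate, (in1, in2), grad_output)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
ForwardBackward.apply(x, y).sum().backward()
print(x.grad)  # gradient of the surrogate x + y w.r.t. x: all ones
```

Recomputing the surrogate inside backward (rather than saving its graph from forward) trades a little extra compute for not having to keep a live graph alive between the two passes.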

Thanks a lot in advance for your help!