How to delegate to other layers in a custom backward function?

I am trying to create a custom layer with a custom gradient.

This layer should effectively act like two normal nn.Linear() layers in parallel, where one of them is slightly modified.

Since I don’t want to reinvent the wheel and define the weights and biases for the contained linear layers myself, I’m wondering how I can use these two existing nn.Linear() layers directly in my custom backward() function.

This is what I have tried. It doesn’t even give an error message; it just crashes Python:

import torch
import torch.nn as nn
from torch import autograd

class CustomFunction(autograd.Function):

    @staticmethod
    def forward(ctx, input, main_layer, aux_layer):
        # Both feed-forward portions are just the result of applying the respective layer to the input
        output_main = main_layer(input)
        output_aux = aux_layer(input)
        ctx.save_for_backward(input, output_main, output_aux)
        return output_main, output_aux

    @staticmethod
    def backward(ctx, grad_main, grad_aux):
        input, output_main, output_aux = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Channel the gradient of the main output into the input
            # (it's supposed to just call the backward() function of the main_layer)
            grad_input = output_main.backward(grad_main)
            # (custom code will go here later to use grad_aux, once the bug above is resolved)
        return grad_input, None, None

class CustomModule(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.main_layer = nn.Linear(num_inputs, num_outputs)
        self.aux_layer = nn.Linear(num_inputs, num_outputs)

    def forward(self, input):
        output_main, output_aux = CustomFunction.apply(input, self.main_layer, self.aux_layer)
        return output_main, output_aux
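One likely issue with the attempt above: autograd is disabled while `autograd.Function.forward()` runs, so `main_layer(input)` records no graph and `output_main` carries no history back to the layer. Also, because only the module objects (not their parameters) are passed to `apply`, autograd does not know the weights are inputs of the function. The following is a minimal sketch under those assumptions, not a definitive solution. The class name `ParallelLinearFunction` is hypothetical. It passes each `nn.Linear`'s `weight` and `bias` into `apply` explicitly, then "delegates" the main branch's backward pass by re-running `F.linear` under `torch.enable_grad()` and calling `torch.autograd.grad`. The aux branch returns `None` gradients as a placeholder for the custom logic that would use `grad_aux`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd


class ParallelLinearFunction(autograd.Function):
    # Hypothetical name. The nn.Linear parameters are passed in as tensors,
    # so autograd treats them as inputs and expects gradients for them.
    @staticmethod
    def forward(ctx, input, w_main, b_main, w_aux, b_aux):
        ctx.save_for_backward(input, w_main, b_main, w_aux, b_aux)
        return F.linear(input, w_main, b_main), F.linear(input, w_aux, b_aux)

    @staticmethod
    def backward(ctx, grad_main, grad_aux):
        input, w_main, b_main, w_aux, b_aux = ctx.saved_tensors
        # Grad mode is off inside backward() unless create_graph=True,
        # so re-enable it to re-run the main layer and let autograd
        # compute its gradients (i.e. reuse the layer's own backward).
        with torch.enable_grad():
            inp = input.detach().requires_grad_(True)
            w = w_main.detach().requires_grad_(True)
            b = b_main.detach().requires_grad_(True)
            out = F.linear(inp, w, b)
            g_in, g_w, g_b = torch.autograd.grad(out, (inp, w, b), grad_main)
        # Placeholder: custom code using grad_aux would go here.
        # Returning None gives the aux layer no gradient.
        return g_in, g_w, g_b, None, None


class CustomModule(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.main_layer = nn.Linear(num_inputs, num_outputs)
        self.aux_layer = nn.Linear(num_inputs, num_outputs)

    def forward(self, input):
        return ParallelLinearFunction.apply(
            input,
            self.main_layer.weight, self.main_layer.bias,
            self.aux_layer.weight, self.aux_layer.bias,
        )
```

With this arrangement the weights and biases still live in the two `nn.Linear` modules, so they show up in `parameters()` and optimizers as usual. The cost is one extra forward pass of the main layer during backward; saving intermediate results instead would avoid that, at the price of more bookkeeping.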