Custom forward function of convolution layer

Hello everyone,

I’m trying to implement a custom convolution layer. However, I only want to customize the forward function and leave the backward function unchanged.

Following the guidance from, I think the implementation would be something like

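import torch
from torch.autograd import Function

class MyConv2d(Function):

    @staticmethod
    def forward(ctx, input, weight):
        # Save the original inputs; backward only needs these
        ctx.save_for_backward(input, weight)
        # f(·) and g(·) are the custom forward computations (placeholders)
        input_temp, weight_temp = f(input, weight)
        output = g(input_temp, weight_temp)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Compute backward with the original inputs, i.e. the gradients of
        # output = conv2d(input, weight) w.r.t. input and weight,
        # ignoring the functions in the middle (f(·), g(·))
        grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output)
        grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)
        return grad_input, grad_weight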

where f(·) and g(·) are custom functions for computing the convolution. One example would be to flatten the input, turning the convolution into a matrix multiplication; in this case, f(·) could be unfold(·). Because the backward pass of unfold(·) is very slow, I don’t want this operation to be part of the computational graph. In other words, no matter what f(·) and g(·) are, I would like autograd to always treat this layer as output = nn.functional.conv2d(input, weight), so that the computational graph contains only input, weight, and output, without any intermediate variables or functions.
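As a concrete sketch of the unfold case, the hypothetical class `UnfoldConv2d` below computes the forward pass with `F.unfold` plus a matrix multiplication, and returns plain conv2d gradients in backward via `torch.nn.grad.conv2d_input` and `torch.nn.grad.conv2d_weight`. It assumes no bias, `dilation=1`, `groups=1`, and integer `stride`/`padding`. Because `Function.forward` runs with autograd disabled, `unfold` never enters the graph.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


class UnfoldConv2d(Function):
    """Hypothetical example: forward = unfold + matmul, backward = conv2d grads.

    Assumes no bias, dilation=1, groups=1, and integer stride/padding.
    """

    @staticmethod
    def forward(ctx, input, weight, stride, padding):
        ctx.save_for_backward(input, weight)
        ctx.stride, ctx.padding = stride, padding
        n, _, h, w = input.shape
        out_ch, _, kh, kw = weight.shape
        # f(·): flatten input patches -> (N, C*kh*kw, L); not recorded by autograd
        cols = F.unfold(input, (kh, kw), padding=padding, stride=stride)
        # g(·): convolution as a matrix multiplication -> (N, out_ch, L)
        out = weight.reshape(out_ch, -1) @ cols
        out_h = (h + 2 * padding - kh) // stride + 1
        out_w = (w + 2 * padding - kw) // stride + 1
        return out.reshape(n, out_ch, out_h, out_w)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        # Gradients of output = conv2d(input, weight), ignoring unfold/matmul
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(
                input.shape, weight, grad_output,
                stride=ctx.stride, padding=ctx.padding)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(
                input, weight.shape, grad_output,
                stride=ctx.stride, padding=ctx.padding)
        # stride and padding are non-tensor inputs, so their gradients are None
        return grad_input, grad_weight, None, None
```

Call it as `UnfoldConv2d.apply(x, w, stride, padding)` with positional arguments, since `apply` does not take keyword arguments. How its backward speed compares with the built-in one may vary across PyTorch versions, so timing it on your own setup is advisable.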

The pseudocode above requires customizing the gradient computation of the convolution. I tried several approaches, but none of them matches the speed of the original conv2d function. I’m wondering:

  1. How can I implement the backward function of a conv layer so that it is as fast as the built-in one in PyTorch?
  2. If that isn’t possible, is there a simpler way to handle this problem?
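For question 1, one option worth testing is to call the ATen convolution backward operator directly, so that both gradients come from a single call, which in recent PyTorch releases is the operator autograd itself uses for conv2d. The helper `conv2d_grads` is hypothetical and assumes a PyTorch version that exposes `torch.ops.aten.convolution_backward`; older versions may not have it.

```python
import torch
import torch.nn.functional as F


def conv2d_grads(input, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    """Hypothetical helper: (grad_input, grad_weight) for output = conv2d(input, weight).

    Assumes torch.ops.aten.convolution_backward is available (recent PyTorch).
    """
    def pair(v):
        # Expand an int to a 2-element list for 2D convolution
        return [v, v] if isinstance(v, int) else list(v)

    grad_input, grad_weight, _ = torch.ops.aten.convolution_backward(
        grad_output, input, weight,
        None,                  # bias_sizes: no bias
        pair(stride), pair(padding), pair(dilation),
        False,                 # transposed
        [0, 0],                # output_padding
        groups,
        [True, True, False],   # output_mask: grad_input, grad_weight, no grad_bias
    )
    return grad_input, grad_weight
```

Inside a custom `Function.backward`, a single call like this could replace separate input and weight gradient computations.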

Please help and thanks!!

One way I found to solve this problem is to first compute output = nn.functional.conv2d(input, weight), then perform the custom computations to get output_temp, and finally replace the values of output with those of output_temp.
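The workaround just described can be written with a common autograd idiom (sometimes called the straight-through trick): add the detached difference between the custom result and the conv2d result, so the forward value is the custom one while gradients flow only through conv2d. `custom_forward` is a hypothetical placeholder for the custom computation.

```python
import torch
import torch.nn.functional as F


def conv2d_with_custom_values(input, weight, custom_forward):
    """Values come from custom_forward; gradients come from conv2d.

    `custom_forward` is a hypothetical stand-in for the custom computation and
    must return a tensor with the same shape as the conv2d output.
    """
    output = F.conv2d(input, weight)                 # tracked by autograd
    with torch.no_grad():
        output_temp = custom_forward(input, weight)  # not tracked
    # Forward value equals output_temp (up to rounding); the detached
    # difference contributes no gradient, so backward is that of conv2d.
    return output + (output_temp - output).detach()
```

Note that this still runs the regular conv2d forward in addition to the custom computation.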

However, the drawback is that the convolution is computed twice, which slows down training a lot. Are there any other strategies available? Thanks!