How to create a custom convolutional autograd function

@Lucca did you have any success with your implementation? I am trying to do exactly the same thing to experiment with sparsifying the gradient. I would be very happy if you could share your solution, if you found one.

I have the forward pass working thanks to Custom convolution layer - PyTorch Forums,

but I am not sure what to put in backward.

import torch
import torch.nn.functional as F

class SparseConv2d(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, kernel_size, out_channels, dilation, padding, stride, info, bias=None):
        # Save the differentiable inputs in the context object for later use in backward
        ctx.save_for_backward(input, weight, bias)
        
        # Save non-differentiable argument(s)
        ctx.info = info 
        
        
        # input: [batchSize, in_channels, width, height]
        width = ((input.shape[2] + 2*padding[0] - dilation[0]*(kernel_size - 1) - 1) // stride[0]) + 1
        height = ((input.shape[3] + 2*padding[1] - dilation[1]*(kernel_size - 1) - 1) // stride[1]) + 1

        # Extract sliding local blocks: [batch, in_channels*k*k, L] with L = width*height
        windows = F.unfold(input, kernel_size=(kernel_size, kernel_size), padding=padding, dilation=dilation, stride=stride)
        windows = windows.transpose(1, 2).contiguous().view(-1, input.shape[1], kernel_size*kernel_size)
        windows = windows.transpose(0, 1)  # [in_channels, batch*L, k*k]

        output = torch.zeros([input.shape[0]*out_channels, width, height], dtype=torch.float32, device=input.device)

        # Loop over input and output channels, accumulating one (out_channel, batch) block at a time
        for channel in range(input.shape[1]):
            for outChannel in range(out_channels):
                # Flatten the k x k kernel so the matmul contracts over the k*k dimension
                res = torch.matmul(windows[channel], weight[outChannel][channel].reshape(-1))
                res = res.view(-1, width, height)
                output[outChannel * res.shape[0] : (outChannel + 1) * res.shape[0]] += res
                
        # Rows are laid out out-channel-major above, so regroup to [batch, out_channels, width, height]
        output = output.view(out_channels, input.shape[0], width, height).transpose(0, 1).contiguous()
       
        if bias is not None:
            output += bias.view(1, -1, 1, 1)
        return output
    
    
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
      
        
        # UNSURE WHAT TO PUT HERE
      
        return grad_input, grad_weight, None, None, None, None, None, None, grad_bias
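For reference, this is the direction I am currently considering for backward: a minimal, untested sketch that assumes the forward above is equivalent to a plain F.conv2d (standard weight layout [out_channels, in_channels, k, k]) and that forward also stores the hyper-parameters on ctx (e.g. ctx.stride = stride, ctx.padding = padding, ctx.dilation = dilation), which my code does not do yet. It uses torch.nn.grad.conv2d_input / conv2d_weight (so it needs import torch.nn.grad at the top of the file), and the commented sparsify line just marks where I would like to plug in the gradient sparsification (sparsify is a hypothetical helper, not an existing function).

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # This is where the gradient could be sparsified before being propagated further,
        # e.g. grad_output = sparsify(grad_output)  # hypothetical helper

        if ctx.needs_input_grad[0]:
            # Gradient w.r.t. the input (a transposed convolution of grad_output with the weights)
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output,
                                                    stride=ctx.stride, padding=ctx.padding, dilation=ctx.dilation)
        if ctx.needs_input_grad[1]:
            # Gradient w.r.t. the weights (a correlation of the input with grad_output)
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output,
                                                      stride=ctx.stride, padding=ctx.padding, dilation=ctx.dilation)
        if bias is not None and ctx.needs_input_grad[8]:
            # Bias gradient: sum grad_output over the batch and spatial dimensions
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        # One gradient per forward argument; the non-tensor arguments get None
        return grad_input, grad_weight, None, None, None, None, None, None, grad_bias

If anyone goes down this road, torch.autograd.gradcheck with small double-precision inputs seems like a good way to verify the backward against the forward.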