Custom autograd.Function for quantized C++ simulator

Hello,

I am trying to bypass the standard forward pass of Conv2d. I am working with QuantConv2d in Brevitas, which relies on PyTorch's implementation of the Conv2d layer. In the forward pass the weights are quantized and the convolution output is computed with these quantized weights; in the backward pass the updates are applied to the floating-point weights.
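In other words, the layer behaves like a straight-through estimator: the quantized weights are used in the forward pass, while gradients flow to the floating-point weights in the backward pass. A minimal sketch of that pattern (illustrative only, not Brevitas's actual implementation):

import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Quantize (here: plain rounding) for the forward computation
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient straight through to the floating-point input
        return grad_output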
Instead of the regular PyTorch Conv2d implementation, I use my own simulator, which computes the output in C++ on NumPy arrays. This breaks my computational graph and makes it impossible for the layer to learn, so I am implementing a custom autograd Function, which looks like this:

import torch
import torch.nn.functional as F

class My_Conv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, quant_kernels, fp_kernels, bias=None, padding=0, stride=1, dilation=1):
        # Save for backward; fp_kernels are saved because these are the ones I want to update
        ctx.save_for_backward(input, fp_kernels, bias)
        ctx.padding = padding
        ctx.stride = stride
        ctx.dilation = dilation

        # Reference output from PyTorch, kept only for debugging/comparison:
        # out = F.conv2d(input[0], quant_kernels[0], bias, stride, padding, dilation)

        # Call my simulator
        out = acs_conv2(input, quant_kernels, padding, stride, dilation, bias=bias)

        # This was suggested to connect fp_kernels (and bias) to the output for gradient propagation
        out = out + fp_kernels.sum() * 0
        if bias is not None:
            out = out + bias.sum() * 0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, fp_kernels, bias = ctx.saved_tensors
        padding = ctx.padding
        stride = ctx.stride
        dilation = ctx.dilation

        grad_input = grad_weight = grad_bias = None

        # needs_input_grad is indexed by forward's arguments:
        # 0 = input, 1 = quant_kernels, 2 = fp_kernels, 3 = bias
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, fp_kernels, grad_output, stride, padding, dilation)
        if ctx.needs_input_grad[2]:
            grad_weight = torch.nn.grad.conv2d_weight(input, fp_kernels.shape, grad_output, stride, padding, dilation)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum((0, 2, 3))

        # One gradient per forward argument:
        # input, quant_kernels, fp_kernels, bias, padding, stride, dilation
        return grad_input, None, grad_weight, grad_bias, None, None, None
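
For context, here is a simplified sketch of how such a Function gets wired into a layer's forward. My real setup goes through Brevitas's QuantConv2d and its weight quantizer; the round() below is just a stand-in for the quantizer:

class MySimConv2d(torch.nn.Module):
    def __init__(self, in_ch, out_ch, k):
        super().__init__()
        self.fp_kernels = torch.nn.Parameter(torch.randn(out_ch, in_ch, k, k))

    def forward(self, x):
        quant_kernels = torch.round(self.fp_kernels)  # stand-in for the Brevitas quantizer
        return My_Conv2d.apply(x, quant_kernels, self.fp_kernels)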

This still doesn't solve the problem: the weights are not updated across epochs. Could you please give me some guidance and help me see where I am mistaken? Any help or pointers for further reading would be much appreciated, thank you!

Did you check that gradients are able to flow through this particular custom autograd.Function? E.g. by passing inputs that require grad directly to the function and calling backward on its outputs.
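
A minimal connectivity check along those lines could look like the sketch below (arbitrary shapes, and it assumes acs_conv2 accepts plain tensors; note that torch.autograd.gradcheck would report a mismatch here by design, since the backward deliberately uses the floating-point weights rather than the quantized ones):

x = torch.randn(2, 3, 8, 8, requires_grad=True)
w_fp = torch.randn(4, 3, 3, 3, requires_grad=True)

out = My_Conv2d.apply(x, torch.round(w_fp), w_fp)  # bias/padding/stride/dilation use their defaults
out.sum().backward()

# If the Function is wired correctly, both gradients exist:
print(x.grad is not None)     # should print True
print(w_fp.grad is not None)  # should print True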

Hi, thanks for your help. I checked the gradients, and the problem was that the outputs of my simulator didn't keep requires_grad = True. Moreover, QuantTensors from Brevitas are not able to propagate gradients by default, so it was necessary to pass only the tensor part of the QuantTensor. This solved the issue, but you gave me the right direction, thank you! :slight_smile:
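
For future readers, the fix amounted to something like the hypothetical helper below (run_simulator is a made-up name; it assumes the Brevitas QuantTensor exposes its raw tensor as .value and that the simulator returns a NumPy array):

import numpy as np
import torch

def run_simulator(x, w, padding, stride, dilation, bias=None):
    # Pass only the tensor part of a QuantTensor to the simulator
    # (assumption: Brevitas QuantTensor stores it in .value)
    x_t = x.value if hasattr(x, "value") else x
    w_t = w.value if hasattr(w, "value") else w

    # acs_conv2 is the C++/NumPy simulator from the snippet above
    out_np = acs_conv2(x_t.detach().cpu().numpy(), w_t.detach().cpu().numpy(),
                       padding, stride, dilation, bias=bias)

    # Back into the torch world; when called inside My_Conv2d.forward,
    # gradient connectivity comes from backward(), not from this tensor's flags
    return torch.from_numpy(np.ascontiguousarray(out_np)).to(x_t.device, x_t.dtype)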