Hello,
I am trying to bypass the standard forward pass of Conv2d. I am working with QuantConv2d in Brevitas, which relies on PyTorch's implementation of the Conv2d layer. In the forward pass, the weights are quantized and the convolution output is computed with these quantized weights. In the backward pass, the updates are applied to the floating-point weights.
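The forward/backward split described above is commonly implemented with a straight-through estimator (STE). Below is a minimal sketch that assumes a simple symmetric uniform quantizer. `fake_quantize` is a hypothetical helper, not the Brevitas quantizer. The `detach()` trick makes the forward pass use the quantized weights, while the gradient reaches the floating-point weights unchanged.

```python
import torch
import torch.nn.functional as F


def fake_quantize(w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Hypothetical symmetric uniform quantizer with a straight-through estimator.

    Forward: returns the rounded and clipped (quantized) weights.
    Backward: the detach() trick makes d(w_q)/d(w) == 1, so gradients
    flow to the floating-point weights unchanged.
    """
    qmax = 2 ** (num_bits - 1) - 1
    scale = w.detach().abs().max().clamp(min=1e-8) / qmax
    w_q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    # Value of w_q, but the gradient of the identity with respect to w
    return w + (w_q - w).detach()


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
w_fp = torch.randn(4, 3, 3, 3, requires_grad=True)  # floating-point weights

out = F.conv2d(x, fake_quantize(w_fp), padding=1)
out.sum().backward()

# The gradient lands on the floating-point weights
print(w_fp.grad.shape)  # torch.Size([4, 3, 3, 3])
```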
Instead of the regular PyTorch Conv2d implementation, I use my own simulator, which computes the output with C++ and NumPy arrays. This breaks the computational graph, so the layer cannot learn. To fix this, I am implementing a custom autograd function that looks like this:
```python
import torch
import torch.nn.functional as F


class My_Conv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, quant_kernels, fp_kernels, bias=None, padding=0, stride=1, dilation=1):
        # Save for backward: fp_kernels, because these are the weights I want to update
        ctx.save_for_backward(input, fp_kernels, bias)
        ctx.padding = padding
        ctx.stride = stride
        ctx.dilation = dilation
        # PyTorch reference output (overwritten by the simulator call below)
        out = F.conv2d(input[0], quant_kernels[0], bias, stride, padding, dilation)
        # Call my simulator (acs_conv2 is defined elsewhere)
        out = acs_conv2(input, quant_kernels, padding, stride, dilation, bias=bias)
        # Suggested as a way to connect fp_kernels to the output for gradient propagation
        out = out + fp_kernels.sum() * 0
        if bias is not None:
            out = out + bias.sum() * 0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, fp_kernels, bias = ctx.saved_tensors
        padding = ctx.padding
        stride = ctx.stride
        dilation = ctx.dilation
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, fp_kernels, grad_output, stride, padding, dilation)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, fp_kernels.shape, grad_output, stride, padding, dilation)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0, 2, 3))
        # One gradient per forward() argument
        return grad_input, None, grad_weight, grad_bias, None, None, None
```
This still doesn't solve the problem: the weights are not updated across epochs. Could you please give me some guidance and help me see where I am going wrong? Any help or pointers to further reading would be much appreciated. Thank you!
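For comparison, here is a minimal self-contained sketch of the same pattern. It assumes stride 1, no padding, no dilation, and NumPy >= 1.20. It also uses a pure NumPy stand-in, `numpy_conv2d` (hypothetical), in place of the C++/NumPy simulator.

Two properties of `torch.autograd.Function` matter here:

- `backward` returns one gradient per `forward` argument, in the same order. The `ctx.needs_input_grad` index checked for each gradient should therefore match its slot (here `fp_w` is argument 2 and `bias` is argument 3).
- `.apply(...)` already connects the output to every tensor input that requires grad. Gradient tracking is disabled inside `forward`, so a `* 0` term is not needed there.

```python
import numpy as np
import torch
import torch.nn.functional as F


def numpy_conv2d(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the external simulator.

    Valid cross-correlation (stride 1, no padding), the same operation as
    F.conv2d. x: (N, C, H, W), w: (O, C, kH, kW) -> (N, O, H-kH+1, W-kW+1).
    """
    kh, kw = w.shape[2], w.shape[3]
    # windows has shape (N, C, H-kH+1, W-kW+1, kH, kW)
    windows = np.lib.stride_tricks.sliding_window_view(x, (kh, kw), axis=(2, 3))
    return np.einsum("nchwij,ocij->nohw", windows, w)


class SimConv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, quant_w, fp_w, bias=None):
        # Non-differentiable simulator call: tensor -> NumPy -> tensor
        out_np = numpy_conv2d(x.detach().cpu().numpy(), quant_w.detach().cpu().numpy())
        out = torch.from_numpy(out_np).to(device=x.device, dtype=x.dtype)
        if bias is not None:
            out = out + bias.view(1, -1, 1, 1)
        # Keep the weights actually used in the forward pass for the input gradient
        ctx.save_for_backward(x, quant_w)
        ctx.has_bias = bias is not None
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, quant_w = ctx.saved_tensors
        grad_x = grad_fp_w = grad_bias = None
        # needs_input_grad follows forward's argument order: (x, quant_w, fp_w, bias)
        if ctx.needs_input_grad[0]:
            grad_x = torch.nn.grad.conv2d_input(x.shape, quant_w, grad_out)
        if ctx.needs_input_grad[2]:
            # Straight-through: the weight gradient is routed to fp_w
            grad_fp_w = torch.nn.grad.conv2d_weight(x, quant_w.shape, grad_out)
        if ctx.has_bias and ctx.needs_input_grad[3]:
            grad_bias = grad_out.sum((0, 2, 3))
        # One entry per forward input, same order; None for quant_w
        return grad_x, None, grad_fp_w, grad_bias


torch.manual_seed(0)
x = torch.randn(2, 3, 6, 6, requires_grad=True)
fp_w = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad=True)
quant_w = torch.round(fp_w.detach() * 4) / 4  # hypothetical quantizer, no gradient

out = SimConv2d.apply(x, quant_w, fp_w, bias)
out.sum().backward()
print(fp_w.grad is not None, bias.grad is not None)  # True True
```

`torch.nn.grad.conv2d_weight` only uses the weight shape. Passing `quant_w` instead of `fp_w` therefore changes only the input gradient, computed by `conv2d_input`. Using the quantized weights there is one common straight-through choice, but it is a design decision rather than a requirement.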