Custom autograd function doesn't learn as well

I'm using a custom autograd function for convolutional layers, as shown in the code snippet below:

import torch
import torch.nn.functional as F
from torch.autograd import Function


class conv_autograd(Function):

    @staticmethod
    def forward(ctx, input, weight, stride, groups, padding, dilation, bias=None, mask=None):
        # Save tensors and the non-tensor hyperparameters for the backward pass
        ctx.save_for_backward(input, weight, bias, mask)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        if mask is None:
            conv_out = F.conv2d(input, weight, stride=stride, groups=groups, padding=padding, dilation=dilation, bias=None)
        else:
            # Apply the fixed mask to the weights before the convolution
            conv_out = F.conv2d(input, weight * mask, stride=stride, groups=groups, padding=padding, dilation=dilation, bias=None)
        return conv_out

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias, mask = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        dilation = ctx.dilation
        groups = ctx.groups

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = F.grad.conv2d_input(input.shape, weight, grad_output, stride=stride, padding=padding, dilation=dilation, groups=groups)
        if ctx.needs_input_grad[1]:
            grad_weight = F.grad.conv2d_weight(input, weight.shape, grad_output, stride=stride, padding=padding, dilation=dilation, groups=groups)
            if mask is not None:
                # Keep masked-out weights frozen by masking their gradients as well
                grad_weight *= mask
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias, None, None, None, None, None

The goal is to train the weights of the convolutional layer while keeping the mask fixed. However, even when I'm not using the mask (mask = None), the model's accuracy does not cross 38% when training ResNet-50 on CIFAR-100. What am I doing wrong?
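For context, here is roughly how a Function like this gets wired into a model; the MaskedConv2d wrapper below is just an illustrative sketch, not my exact training code:

import torch
import torch.nn as nn


class MaskedConv2d(nn.Module):
    # Illustrative wrapper around the custom Function; the real training code may differ.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, mask=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, kernel_size, kernel_size))
        nn.init.kaiming_normal_(self.weight)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        # The fixed (non-trainable) mask is registered as a buffer, not a Parameter
        if mask is not None:
            self.register_buffer("mask", mask)
        else:
            self.mask = None

    def forward(self, x):
        # Custom autograd Functions must be invoked through .apply, not called directly
        return conv_autograd.apply(x, self.weight, self.stride, self.groups,
                                   self.padding, self.dilation, None, self.mask)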

You can use torch.autograd.gradcheck (gradcheck — PyTorch 2.5 documentation) to validate that it's producing the correct gradients for a given set of inputs.


I thought gradcheck only checks whether the function produces any gradients at all? How would you validate that the gradients are actually correct just through gradcheck?

For a given set of inputs and your function, gradcheck computes the gradients once with your provided backward and once numerically via finite differencing. If the two results match, we can be confident that, assuming the forward is correct, the backward is also correct.
This should also tell you if one of the gradients w.r.t. the inputs is unexpectedly zero.
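For example, a minimal sketch (assuming small double-precision tensors so the finite-difference comparison is reliable, and hyperparameter values chosen just for illustration):

import torch
from torch.autograd import gradcheck

# Small double-precision inputs; gradcheck compares the analytical backward
# against numerical (finite-difference) gradients.
x = torch.randn(2, 3, 8, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 3, 3, 3, dtype=torch.double, requires_grad=True)
mask = torch.randint(0, 2, (4, 3, 3, 3)).double()  # fixed mask, no grad needed

# Arguments are passed in the same order conv_autograd.forward expects:
# (input, weight, stride, groups, padding, dilation, bias, mask)
test = gradcheck(conv_autograd.apply, (x, w, 1, 1, 1, 1, None, mask), eps=1e-6, atol=1e-4)
print(test)  # True if the analytical and numerical gradients match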