Custom 2d convolution accuracy issue

I’ve been trying to build a custom channel-wise convolution operation with multiplication and sum operation so that I can alter the bit-precision during the convolution process.
(I’ve checked several forum articles and all of them told me I should build it with unfold and fold functions.)

But when I replace the torch.nn.conv2d function with my code
There is no error but accuracy seems to degrade a lot
I’m completely new to python and pytorch, and only know basics about neural network so I can’t find what mistakes I made. :cry:

Can anyone please help? I would really appreciate any help.
Also, if there is any tips to save gpu usage I would also appreciate it very much, since my code seems to consume tons of memory so I had to cut the batch size and it’s taking forever to see the results. :sob:

class Conv2DFunctionCustom(Function):
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        ctx.save_for_backward(input, weight, bias)
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups

        zeropad = nn.ZeroPad2d(ctx.padding[0])
        batch_size = len(input)
        input_channel = len(input[0])
        input_size = len(input[0][0])
        ochannel = len(weight)
        ichannel = len(weight[0])
        kernel_size = len(weight[0][0])
        if padding:
            inp_pad = zeropad(input)
            inp_pad = input

        input_unfold = torch.nn.functional.unfold(inp_pad, kernel_size).transpose(1, 2)

        if input_channel < 256:
            input_unfold = input_unfold.unfold(2,input_channel,input_channel)[None,:]
            input_unfold = input_unfold.unfold(2,256,256)[None,:]

        weight_unfold = weight.view(weight.size(0), -1)

        if input_channel < 256:
            weight_unfold = weight_unfold.unfold(1,ichannel,ichannel)[:,None,None]
            weight_unfold = weight_unfold.unfold(1,256,256)[:,None,None]
        mac = input_unfold * weight_unfold
        mac = torch.sum(mac, dim=4)
        mac = torch.sum(mac, dim=3)
        out = torch.nn.functional.fold(mac, (input_size, input_size), (1, 1)).transpose(0,1)

        return out

    def backward(ctx, grad_output):

        input, weight, bias = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
        grad_input = grad_weight = grad_weight = grad_input = grad_bias = grad_stride = grad_padding = grad_dilation = grad_groups = None
        if ctx.needs_input_grad[1]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups)
        if ctx.needs_input_grad[2]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding, dilation, groups)

        return grad_input, grad_weight, grad_input, grad_weight, grad_bias, grad_stride, grad_padding, grad_dilation, grad_groups

Did you compare the forward and backward implementation to the PyTorch implementations and did you see any issues?
The forward could be checked by calculating the abs().max() error of the outputs, while torch.autograd.gradcheck might be helpful to check the backward implementation.