Variation in weights between custom Conv2d and PyTorch Conv2d?

I am trying to build a custom convolution using the approach shown in the documentation of PyTorch's unfold function.
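For context, F.unfold extracts sliding local blocks from the input: for an input of shape (N, C, H, W) it returns a tensor of shape (N, C*kh*kw, L), where L is the number of output spatial locations. A quick shape check, using the same sizes as in the test further down:

    import torch
    import torch.nn.functional as F

    x = torch.randn(10, 3, 64, 64)
    # 3x3 kernel, padding 1, stride 1 keeps the 64x64 grid -> 4096 locations
    cols = F.unfold(x, kernel_size=3, padding=1, stride=1)
    print(cols.shape)  # torch.Size([10, 27, 4096]); 27 = 3 channels * 3 * 3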

The custom convolution function is given below:

    import torch
    from torch import nn
    import torch.nn.functional as F
    from torch.nn.parameter import Parameter
    import math
    from torch.nn.modules.utils import _pair
    
    
    class customConv(nn.Module):
        def __init__(self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1, bias=True):
            super(customConv, self).__init__()
    
            self.kernel_size = _pair(kernel_size)
            self.out_channels = out_channels
            self.dilation = _pair(dilation)
            self.padding = _pair(padding)
            self.stride = _pair(stride)
            self.n_channels = n_channels
            self.weight = Parameter(torch.Tensor(self.out_channels, self.n_channels, self.kernel_size[0], self.kernel_size[1]))
            if bias:
                self.bias = Parameter(torch.Tensor(out_channels))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()
    
        def reset_parameters(self):
            # Uniform init with bound 1/sqrt(fan_in), matching the classic nn.Conv2d default
            n = self.n_channels
            for k in self.kernel_size:
                n *= k
            stdv = 1. / math.sqrt(n)
            self.weight.data.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.data.uniform_(-stdv, stdv)
    
        def forward(self, input_):
            # Output spatial size, same formula nn.Conv2d uses
            hout = ((input_.shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0]) + 1
            wout = ((input_.shape[3] + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1]) + 1

            # Extract sliding local blocks: (N, C*kh*kw, hout*wout)
            inputUnfolded = F.unfold(input_, kernel_size=self.kernel_size, padding=self.padding, dilation=self.dilation, stride=self.stride)
            # Convolution as a matrix multiplication with the flattened weight.
            # Note: truth-testing a multi-element Parameter is ambiguous, so check
            # against None instead of writing "if self.bias:".
            if self.bias is not None:
                convolvedOutput = (inputUnfolded.transpose(1, 2).matmul(
                    self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)) + self.bias.view(-1, 1)
            else:
                convolvedOutput = inputUnfolded.transpose(1, 2).matmul(
                    self.weight.view(self.weight.size(0), -1).t()).transpose(1, 2)
            # Reshape (N, out_channels, hout*wout) back into (N, out_channels, hout, wout)
            convolutionReconstruction = convolvedOutput.view(input_.shape[0], self.out_channels, hout, wout)
            return convolutionReconstruction

But when I compare it with the PyTorch implementation, I do not get exactly the same values. The code I use to check the difference is provided below:


    import torch
    from torch import nn
    from customConvolve import customConv   
    torch.manual_seed(1)
    
    input = torch.randn(10, 3, 64, 64)

    conv1 = nn.Conv2d(input.shape[1], 5, kernel_size=3, dilation=1, padding=1, stride=1, bias=False)
    conv1_output = conv1(input)

    conv2 = customConv(n_channels=input.shape[1], out_channels=5, kernel_size=3, dilation=1, stride=1, padding=1, bias=False)
    conv2_output = conv2(input)
    
    print(torch.equal(conv1.weight.data, conv2.weight.data))
    
    print(torch.equal(conv1_output, conv2_output))

I would like to know why this variation exists and how to resolve it.
Thank you.

Your implementation of the custom convolution looks correct to me, except for two things in the testing code:

  1. In the testing code, the two conv layers do not share the same weights, so their outputs cannot match. You can assign the weights of one conv layer to the other as follows:

    conv1 = nn.Conv2d(input.shape[1], 1, kernel_size=3, dilation=1, padding=1, stride=1 ,bias = False)
    conv2 = customConv(n_channels=input.shape[1], out_channels=1, kernel_size=3,  dilation=1, stride =1, padding = 1, bias = False)
    conv1.weight = conv2.weight
    conv1.bias = conv2.bias
  2. Once you assign the same weights, the conv outputs are very close (they may differ only by numerical precision). You can check this with the L2 error, as in the full comparison sketched below:

    print(torch.pow(conv1_output - conv2_output, 2).sum())
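For completeness, here is a sketch of the full corrected comparison (assuming the customConv class and the customConvolve module from the question). torch.allclose is used because the two implementations accumulate floating-point error in different orders, so torch.equal on the outputs is too strict:

    import torch
    from torch import nn
    from customConvolve import customConv

    torch.manual_seed(1)
    input = torch.randn(10, 3, 64, 64)

    conv1 = nn.Conv2d(input.shape[1], 5, kernel_size=3, dilation=1, padding=1, stride=1, bias=False)
    conv2 = customConv(n_channels=input.shape[1], out_channels=5, kernel_size=3, dilation=1, stride=1, padding=1, bias=False)

    # Share the same weights before comparing outputs
    conv1.weight = conv2.weight

    conv1_output = conv1(input)
    conv2_output = conv2(input)

    print(torch.equal(conv1.weight.data, conv2.weight.data))      # expect True
    print(torch.allclose(conv1_output, conv2_output, atol=1e-6))  # expect True, up to float tolerance
    print(torch.pow(conv1_output - conv2_output, 2).sum())        # expect a tiny L2 error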

If in doubt, take a simple 5 x 5 input and verify the outputs by printing them.

Yes, you are correct. I forgot to assign the same weights. The error is very low now; the remaining variation is due to numerical precision.

Thank you so much.