Decompose Conv2d into unfold-gemm-fold

I implemented the decomposition of Conv2d into im2col -> gemm -> col2im as shown below; this is related to https://discuss.pytorch.org/t/how-to-use-custom-convolution-module-instead-of-torch-nn-conv2d/21458:

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import math

class MyConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super(MyConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        # weight is left uninitialized here; it is filled in from outside in __main__
        self.weight = Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size))

    def forward(self, input):
        h_in, w_in = input.shape[2:]
        h_out = math.floor((h_in + 2*self.padding - self.dilation*(self.kernel_size-1)-1)/self.stride+1)
        w_out = math.floor((w_in + 2*self.padding - self.dilation*(self.kernel_size-1)-1)/self.stride+1)

        # unfold (im2col): x has shape [bs, ksize, num_sliding]
        x = torch.nn.functional.unfold(input, kernel_size=self.kernel_size, padding=self.padding,
                                       stride=self.stride, dilation=self.dilation)

        bs = input.shape[0]
        ksize = self.in_channels*self.kernel_size*self.kernel_size
        num_sliding = x.shape[2]

        assert x.shape[1] == ksize

        # reshape the columns for the GEMM: x becomes [bs*num_sliding, ksize]
        x = torch.transpose(x, 1, 2).reshape(-1, ksize)

        # gemm: [bs*num_sliding, ksize] x [ksize, out_channels] -> [bs*num_sliding, out_channels]
        weight_flat = self.weight.view(self.out_channels, ksize)
        x = torch.mm(x, weight_flat.t())

        # fold (col2im): put the columns back into a [bs, out_channels, h_out, w_out] map
        x = x.reshape(bs, num_sliding, self.out_channels)
        x = torch.transpose(x, 1, 2)
        x = torch.nn.functional.fold(x, output_size=[h_out, w_out], kernel_size=1, padding=0, dilation=self.dilation, stride=1)
        return x


if __name__ == '__main__':
    input = torch.rand(10, 32, 1, 1)
    weight = torch.rand(64, 32, 3, 3)

    conv2d = MyConv2d(32, 64, kernel_size=3, padding=1)
    conv2d.weight.data.copy_(weight)

    y = conv2d(input)

    print(y.mean())

    conv2d = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
    conv2d.weight.data.copy_(weight)

    y = conv2d(input)

    print(y.mean())

It seems that the stride and kernel_size arguments of torch.nn.functional.fold mean something DIFFERENT from the ones passed to torch.nn.functional.unfold, am I right?
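
To illustrate what I mean, here is a minimal sketch (the shapes are made up for the example): with kernel_size=1 and stride=1, fold just rearranges the [bs, C, L] columns into a [bs, C, h_out, w_out] map, i.e. it behaves like a plain reshape, regardless of the kernel_size/stride that were used in unfold:

import torch
import torch.nn.functional as F

bs, c_out, h_out, w_out = 2, 4, 5, 5
cols = torch.rand(bs, c_out, h_out * w_out)  # [bs, C, L] with L = h_out*w_out

folded = F.fold(cols, output_size=(h_out, w_out), kernel_size=1)
reshaped = cols.view(bs, c_out, h_out, w_out)
print(torch.allclose(folded, reshaped))  # True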

I’ve manually checked the outputs of the standard nn.Conv2d against MyConv2d and they are the same.
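
For completeness, here is a sketch of that check using torch.allclose on the full output tensors (it reuses the weight tensor from the snippet above; the 7x7 input size is just an arbitrary choice for the test):

my_conv = MyConv2d(32, 64, kernel_size=3, padding=1)
my_conv.weight.data.copy_(weight)

ref_conv = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
ref_conv.weight.data.copy_(weight)

x = torch.rand(10, 32, 7, 7)
print(torch.allclose(my_conv(x), ref_conv(x), atol=1e-6))  # should print True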

I think it would help to add a hint about this to the documentation to make it clearer.