I implemented the decomposition of Conv2d into im2col -> gemm -> col2im as shown below.
This is related to https://discuss.pytorch.org/t/how-to-use-custom-convolution-module-instead-of-torch-nn-conv2d/21458:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import math
class MyConv2d(nn.Module):
    """Bias-free 2-D convolution computed as im2col -> GEMM -> col2im.

    The input is unfolded into columns of sliding local blocks (im2col),
    multiplied by the flattened weight matrix (GEMM), and the result is
    arranged back into a spatial map (col2im).

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Side length of the square kernel (a single int).
        stride: Step of the sliding window (a single int).
        padding: Implicit zero padding added on both sides of each spatial
            dimension (a single int).
        dilation: Spacing between kernel elements (a single int).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super(MyConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # torch.empty only allocates; initialise explicitly (same scheme as
        # nn.Conv2d) so a freshly built module never holds garbage/NaN values.
        self.weight = Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, input):
        """Apply the convolution to ``input``.

        Args:
            input: 4-D tensor whose dimensions are indexed here as
                (batch, channels, height, width); dimension 1 must equal
                ``in_channels``.

        Returns:
            Tensor of shape (batch, out_channels, h_out, w_out).
        """
        bs = input.shape[0]
        h_in, w_in = input.shape[2:]

        # Size of the span covered by one (dilated) kernel along a dimension.
        span = self.dilation * (self.kernel_size - 1) + 1
        h_out = (h_in + 2 * self.padding - span) // self.stride + 1
        w_out = (w_in + 2 * self.padding - span) // self.stride + 1

        # im2col. stride and dilation must be forwarded here, otherwise the
        # number of sliding blocks will not match h_out * w_out.
        # cols: [bs, ksize, num_sliding]
        cols = torch.nn.functional.unfold(
            input,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )
        ksize = self.in_channels * self.kernel_size * self.kernel_size
        num_sliding = cols.shape[2]
        assert cols.shape[1] == ksize
        assert num_sliding == h_out * w_out

        # GEMM: [bs*num_sliding, ksize] x [ksize, out_channels]
        cols = torch.transpose(cols, 1, 2).reshape(-1, ksize)
        weight_flat = self.weight.view(self.out_channels, ksize)
        out = torch.mm(cols, weight_flat.t())

        # out: [bs, out_channels, num_sliding]
        out = out.reshape(bs, num_sliding, self.out_channels)
        out = torch.transpose(out, 1, 2)

        # col2im. Every column already holds exactly one output pixel, so the
        # fold uses a 1x1 kernel with unit stride/dilation and no padding; it
        # only lays the num_sliding columns out as an (h_out, w_out) grid.
        # The convolution's own kernel_size/stride/dilation/padding were fully
        # consumed by unfold above and must NOT be repeated here.
        out = torch.nn.functional.fold(out, output_size=(h_out, w_out), kernel_size=1)
        return out
if __name__ == '__main__':
    # Sanity check: run the custom module and nn.Conv2d on the same input
    # with identical weights and compare their outputs.
    x = torch.rand(10, 32, 1, 1)  # avoid shadowing the builtin `input`
    weight = torch.rand(64, 32, 3, 3)

    my_conv = MyConv2d(32, 64, kernel_size=3, padding=1)
    my_conv.weight.data.copy_(weight)
    y_mine = my_conv(x)
    print(y_mine.mean())

    ref_conv = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
    ref_conv.weight.data.copy_(weight)
    y_ref = ref_conv(x)
    print(y_ref.mean())

    # Equal means alone can hide element-wise mismatches (e.g. permuted
    # outputs), so also report the largest absolute difference.
    print((y_mine - y_ref).abs().max())
It seems that the `stride` and `kernel_size` arguments of `torch.nn.functional.fold` are different from those of `torch.nn.functional.unfold`. Am I right?
I’ve manually compared the outputs of the standard `Conv2d` and `MyConv2d`, and the results are the same.
I think it would be better to add some hints to the documentation to make this clear.