I have a project that requires me to define a conv2d operation using a linear layer (will be replaced with a custom layer later). My first though was to check the logic of it and implement it using the vanilla nn.Linear layer.
class CustomConv(nn.Module):
"""
2D convolution over an input of shape (batch x channel x H x W)
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
if isinstance(stride, int):
self.stride = (stride, stride)
else:
self.stride = stride
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
else:
self.kernel_size = kernel_size
if isinstance(padding, int):
self.padding = (padding, padding)
else:
self.padding = padding
if isinstance(dilation, int):
self.dilation = (dilation, dilation)
else:
self.dilation = dilation
self.linearized_kernel = torch.nn.Linear(self.kernel_size[0] * self.kernel_size[1]*self.in_channels, self.out_channels, bias=bias)
self.unfold = torch.nn.Unfold(kernel_size=(self.kernel_size[0], self.kernel_size[1]), dilation=dilation, padding=padding, stride=stride)
def forward(self, x):
bsz = x.shape[0]
h, w = x.shape[2:4]
h_out = math.floor((h + 2*self.padding[0] - self.dilation[0]*(self.kernel_size[0] - 1) - 1)/self.stride[0] + 1)
w_out = math.floor((w + 2*self.padding[1] - self.dilation[1]*(self.kernel_size[1] - 1) - 1)/self.stride[1] + 1)
patches = self.unfold(x)
patches = patches.view(bsz, -1, self.kernel_size[0] * self.kernel_size[1]*self.in_channels)
# perform the matrix multiplication
patches = self.linearized_kernel(patches)
patches = patches.view(bsz, self.out_channels, h_out, w_out)
#patches = torch.nn.functional.fold(patches, (h_out, w_out), (1, 1))
# return in the expected shape
return patches
When training a simple CNN model using the custom conv, it learns but only achieves half the accuracy of using a normal conv2d layer. Is there any issue with the logic being used here?