Locally connected layers

As far as I understand this layer, each window position will have its own set of weights.
So if we are using an input of shape [1, 3, 24, 24] with out_channels=2, kernel_size=3, stride=1 and no padding, the spatial output will be 22x22 and we will have 3*3*3*2*(22*22) = 26136 weights, i.e. in_channels * kh * kw * out_channels for each output position. Is this correct?
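
For reference, here is a back-of-the-envelope check of that count against a regular nn.Conv2d, which shares a single kernel across all positions (this assumes my understanding above is right):

import torch.nn as nn

# A regular conv with the same config shares one kernel across all positions
conv2d = nn.Conv2d(3, 2, kernel_size=3, stride=1, bias=False)
print(conv2d.weight.numel())  # 54 = 3 * 3 * 3 * 2
# A locally connected layer would keep one such kernel set per output position
print(54 * 22 * 22)           # 26136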
If that’s the case, I’ve created a quick and dirty implementation of this layer using unfold and a simple multiplication.
Note that I haven’t properly tested all edge cases, so feel free to report any issues you run into with this implementation.
I don’t know if we can somehow skip computing the spatial output shape, as I need it for the internal parameters (a small helper sketched below the class shows one way to derive it from the input size).
Also, I guess the implementation is not the fastest one, but it should be a starting point:

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair

class LocallyConnected2d(nn.Module):
    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False):
        super(LocallyConnected2d, self).__init__()
        output_size = _pair(output_size)
        kernel_size = _pair(kernel_size)
        # One weight set per output position, i.e. no weight sharing across the spatial dims
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kernel_size[0] * kernel_size[1])
        )
        if bias:
            self.bias = nn.Parameter(
                torch.randn(1, out_channels, output_size[0], output_size[1])
            )
        else:
            self.register_parameter('bias', None)
        self.kernel_size = kernel_size
        self.stride = _pair(stride)
        
    def forward(self, x):
        _, c, h, w = x.size()
        kh, kw = self.kernel_size
        dh, dw = self.stride
        # Extract sliding windows: [batch, in_channels, oh, ow, kh, kw]
        x = x.unfold(2, kh, dh).unfold(3, kw, dw)
        # Flatten each window: [batch, in_channels, oh, ow, kh * kw]
        x = x.contiguous().view(*x.size()[:-2], -1)
        # Sum over the in_channels and kernel_size dims
        out = (x.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out += self.bias
        return out
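
Regarding the output shape, one way around passing it in manually could be to derive it from the input size via the usual convolution arithmetic. A minimal sketch, assuming no padding or dilation (the helper name is just made up here):

def calc_output_size(input_size, kernel_size, stride):
    # Standard convolution output size without padding or dilation:
    # floor((input - kernel) / stride) + 1 per spatial dim
    input_size = _pair(input_size)
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    return tuple((i - k) // s + 1
                 for i, k, s in zip(input_size, kernel_size, stride))

print(calc_output_size(24, 3, 1))  # (22, 22)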


batch_size = 5
in_channels = 3
h, w = 24, 24
x = torch.randn(batch_size, in_channels, h, w)

out_channels = 2
output_size = 22
kernel_size = 3
stride = 1
conv = LocallyConnected2d(
    in_channels, out_channels, output_size, kernel_size, stride, bias=True)

out = conv(x)

out.mean().backward()
print(conv.weight.grad)
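
As a quick sanity check (if my math above is right), the output shape and the parameter count should match the numbers from the beginning of the post:

print(out.shape)            # torch.Size([5, 2, 22, 22])
print(conv.weight.numel())  # 26136 = 3 * 3 * 3 * 2 * 22 * 22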