As far as I understand this layer, each window position will have its own set of weights.
So if we are using an input of shape [1, 3, 24, 24] with out_channels=2, kernel_size=3, stride=1, and no padding, we will have 3*3*3*2*(22*22) = 26136 weights. Is this correct?
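As a quick sanity check on that number (plain arithmetic, using the standard output-size formula for an unpadded convolution):

out_h = out_w = (24 - 3) // 1 + 1  # 22
weights_per_position = 3 * 3 * 3 * 2  # kh * kw * in_channels * out_channels = 54
total_weights = weights_per_position * out_h * out_w  # 54 * 484 = 26136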
If that’s the case, I created a quick and dirty implementation of this layer using unfold and a simple multiplication.
Note that I haven’t properly tested all edge cases, so feel free to report any issues you run into with this implementation.
I don’t know if we can somehow skip the computation of the spatial output shape, as I need it for the internal parameters.
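That said, the output shape could at least be derived from an expected input size instead of being passed in explicitly; here is a minimal sketch with a hypothetical helper (the name conv_output_size is mine, not part of the implementation below):

def conv_output_size(in_size, kernel_size, stride):
    # Standard output size of an unpadded convolution
    return (in_size - kernel_size) // stride + 1

# e.g. conv_output_size(24, kernel_size=3, stride=1) == 22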
Also, I guess the implementation is not the fastest one, but it should be a starter:
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair


class LocallyConnected2d(nn.Module):
    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False):
        super(LocallyConnected2d, self).__init__()
        output_size = _pair(output_size)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        # One weight per output location:
        # [1, out_channels, in_channels, out_h, out_w, kh * kw]
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, output_size[0], output_size[1],
                        self.kernel_size[0] * self.kernel_size[1])
        )
        if bias:
            self.bias = nn.Parameter(
                torch.randn(1, out_channels, output_size[0], output_size[1])
            )
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        _, c, h, w = x.size()
        kh, kw = self.kernel_size
        dh, dw = self.stride
        # Extract sliding windows: [N, C, out_h, out_w, kh, kw]
        x = x.unfold(2, kh, dh).unfold(3, kw, dw)
        # Flatten each window: [N, C, out_h, out_w, kh * kw]
        x = x.contiguous().view(*x.size()[:-2], -1)
        # Sum over the in_channels and kernel dims
        out = (x.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out += self.bias
        return out
batch_size = 5
in_channels = 3
h, w = 24, 24
x = torch.randn(batch_size, in_channels, h, w)

out_channels = 2
output_size = 22
kernel_size = 3
stride = 1
conv = LocallyConnected2d(
    in_channels, out_channels, output_size, kernel_size, stride, bias=True)

out = conv(x)
out.mean().backward()
print(conv.weight.grad)
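As a follow-up check on the weight count from the question above, these values follow directly from the shapes used in the module and should match the earlier arithmetic:

print(out.shape)            # torch.Size([5, 2, 22, 22])
print(conv.weight.numel())  # 26136 == 3 * 3 * 3 * 2 * 22 * 22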