Locally connected layers

Are there any plans to provide optimized locally connected layers in PyTorch, like LocallyConnected1D/LocallyConnected2D in Keras (convolution without weight sharing)?

https://keras.io/layers/local/

As far as I understand this layer, each window position has its own set of weights.
So if we are using an input of shape [1, 3, 24, 24] and out_channels=2, kernel_size=3, stride=1 without padding, we will have 3*3*3*2*(22*22) = 26136 weights. Is this correct?
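The arithmetic in that count can be checked with a few lines of Python. This is a minimal sketch; the helper names `locally_connected_weight_count` and `conv_weight_count` are hypothetical, and bias terms are left out, as in the count above.

```python
def locally_connected_weight_count(in_channels, out_channels, kernel_size, out_h, out_w):
    # Every output position owns its own, unshared filter bank
    return in_channels * kernel_size * kernel_size * out_channels * out_h * out_w


def conv_weight_count(in_channels, out_channels, kernel_size):
    # A regular convolution reuses one filter bank at every position
    return in_channels * kernel_size * kernel_size * out_channels


# 24x24 input, 3x3 kernel, stride 1, no padding -> 22x22 output positions
out_size = 24 - 3 + 1
print(locally_connected_weight_count(3, 2, 3, out_size, out_size))  # 26136
print(conv_weight_count(3, 2, 3))  # 54
```

The ratio between the two counts is exactly the number of output positions (22*22 = 484), which is where the extra memory cost of dropping weight sharing comes from.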
If that’s the case, I’ve created a quick and dirty implementation of this layer using unfold and a simple multiplication.
Note that I haven’t properly tested all edge cases, so feel free to report any issues you get with this implementation.
I don’t know if we can somehow skip the computation of the spatial output shape, as I need it for the internal parameters.
Also, I guess the implementation is not the fastest one, but it should be a starter:

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair

class LocallyConnected2d(nn.Module):
    def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False):
        super(LocallyConnected2d, self).__init__()
        output_size = _pair(output_size)
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        kh, kw = self.kernel_size
        # One unshared filter per output position:
        # [1, out_channels, in_channels, out_h, out_w, kh*kw]
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kh * kw)
        )
        if bias:
            # One bias per output channel and spatial position
            self.bias = nn.Parameter(
                torch.randn(1, out_channels, output_size[0], output_size[1])
            )
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        kh, kw = self.kernel_size
        dh, dw = self.stride
        # Extract sliding windows: [N, C, out_h, out_w, kh, kw]
        x = x.unfold(2, kh, dh).unfold(3, kw, dw)
        # Flatten each window: [N, C, out_h, out_w, kh*kw]
        x = x.contiguous().view(*x.size()[:-2], -1)
        # Sum in in_channel and kernel_size dims
        out = (x.unsqueeze(1) * self.weight).sum([2, -1])
        if self.bias is not None:
            out = out + self.bias
        return out


batch_size = 5
in_channels = 3
h, w = 24, 24
x = torch.randn(batch_size, in_channels, h, w)

out_channels = 2
output_size = 22
kernel_size = 3
stride = 1
conv = LocallyConnected2d(
    in_channels, out_channels, output_size, kernel_size, stride, bias=True)

out = conv(x)

out.mean().backward()
print(conv.weight.grad)
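Regarding speed: the broadcast `x.unsqueeze(1) * self.weight` materializes a `[N, out_channels, in_channels, out_h, out_w, kh*kw]` tensor before summing. A possible alternative, sketched here under the assumption that the weight keeps the same layout as in the module above, is to express the same contraction with `torch.einsum`, which should let PyTorch reduce it to a batched matrix multiply over output positions instead. The function names are hypothetical and this has not been benchmarked; the test only checks that both forms agree numerically.

```python
import torch


def lc_forward_broadcast(x, weight, kernel_size, stride):
    # Same computation as LocallyConnected2d.forward above (without bias).
    # x: [N, C, H, W]; weight: [1, out_channels, C, out_h, out_w, kh*kw]
    kh, kw = kernel_size
    dh, dw = stride
    patches = x.unfold(2, kh, dh).unfold(3, kw, dw)  # [N, C, out_h, out_w, kh, kw]
    patches = patches.contiguous().view(*patches.shape[:-2], -1)  # [N, C, out_h, out_w, kh*kw]
    return (patches.unsqueeze(1) * weight).sum([2, -1])  # [N, out_channels, out_h, out_w]


def lc_forward_einsum(x, weight, kernel_size, stride):
    # Same contraction via einsum; no explicit 6-D product tensor is built here.
    kh, kw = kernel_size
    dh, dw = stride
    patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
    patches = patches.reshape(*patches.shape[:-2], kh * kw)  # [N, C, out_h, out_w, kh*kw]
    # n=batch, c=in_channels, o=out_channels, p/q=output row/col, k=kernel offset
    return torch.einsum("ncpqk,ocpqk->nopq", patches, weight[0])


# Usage with the shapes from the example above (bias omitted)
x = torch.randn(5, 3, 24, 24)
w = torch.randn(1, 2, 3, 22, 22, 3 * 3)
print(lc_forward_einsum(x, w, (3, 3), (1, 1)).shape)  # torch.Size([5, 2, 22, 22])
```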

Have you gotten around to testing this? I’m looking to use it.

Any plans to make it official?

If enough people are using this layer, we could think about writing an official (and most likely more performant) version of it. Do you have any papers or other references on this particular layer? I’ve just used the Keras docs to replicate it.

I guess this paper from Facebook should provide enough motivation to implement LC layers.

It seems DeepFace (ensembles) is still the best one:
https://cmusatyalab.github.io/openface/models-and-accuracies/

@ptrblck
Is there any update on an official release?

Any update on this??

Any update? I really want an official one with sufficient tests.

I, too, want to use this. Any updates?

No updates, but PRs are welcome. :)
Would any of you be interested in working on it? CC @ming_li

@ptrblck
It seems there has been a PR since 2017, and many people seem to need this functionality, but that PR is still under review. Could you take a look at it and see what the problem is there? Thank you very much :)

It seems the PR was abandoned and would most likely need a rewrite as the backend methods could have changed, so feel free to post your interest in the PR.

Hello!
Are there any plans in the near future to make it official?

I would also love to use the official version of a PyTorch LocallyConnected2D layer!

Hi, I am a master’s student in cognitive neuroscience, trying to implement a free convolutional/local neural network in PyTorch. I have been trying to use this layer to build a simple 3-layer network that first classifies MNIST (the later goal is to solve face/scene recognition and a combination of face+scene recognition). However, I have been struggling to get decent performance with this network:

from torch import nn
from custom.LocallyConnected2d import LocallyConnected2d

class FreeConvNetwork(nn.Module):
    def __init__(self):
        super().__init__()  # call the parent nn.Module constructor to set up its attributes
        self.LL1 = LocallyConnected2d(1, out_channels=16, output_size=13, kernel_size=3, stride=2, bias=True)
        # output_size = floor((input_size (width/height of img) - kernel_size + 2*padding) / stride) + 1
        # weights = (kernel_size*kernel_size*in_channels + 1 (for bias)) * out_channels * output_size * output_size

        self.LL2 = LocallyConnected2d(16, 32, 6, 3, 2, True)
        self.LL3 = LocallyConnected2d(32, 64, 4, 3, 1, True)

        self.flatten = nn.Flatten()

        self.linear1 = nn.Linear(64*4*4, 512)
        self.linear2 = nn.Linear(512, 10)
        self.activation = nn.ReLU()
        # dim=1 normalizes over the 10 classes; if the loss is nn.CrossEntropyLoss,
        # return the raw logits instead, since that loss already applies log-softmax
        self.softm = nn.Softmax(dim=1)

    def forward(self, x):
        output = self.LL1(x)
        output = self.activation(output)
        output = self.LL2(output)
        output = self.activation(output)
        output = self.LL3(output)
        output = self.activation(output)

        output = self.flatten(output)
        output = self.linear1(output)
        output = self.activation(output)
        output = self.linear2(output)
        output = self.softm(output)
        return output
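The spatial sizes hard-coded in this network (13, 6, 4) and the `64*4*4` input to `linear1` follow from the standard output-size formula, where the `+ 1` counts the first window position rather than the bias. A minimal sketch that walks through the three local layers for a 28x28 MNIST input, assuming each layer is the `LocallyConnected2d` posted earlier in this thread (per-position weights plus a per-position bias); the helper names are hypothetical.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    # floor((size + 2*padding - kernel_size) / stride) + 1;
    # the "+ 1" counts the first window position, not a bias term
    return (size + 2 * padding - kernel_size) // stride + 1


def locally_connected_params(in_channels, out_channels, kernel_size, out_size, bias=True):
    # One filter per output channel and position, plus one bias per channel and position
    per_position = in_channels * kernel_size * kernel_size + (1 if bias else 0)
    return per_position * out_channels * out_size * out_size


# (in_channels, out_channels, kernel_size, stride) for LL1, LL2, LL3
layers = [(1, 16, 3, 2), (16, 32, 3, 2), (32, 64, 3, 1)]
size = 28  # MNIST images are 28x28
for in_ch, out_ch, k, s in layers:
    size = conv_output_size(size, k, s)
    print(size, locally_connected_params(in_ch, out_ch, k, size))
# prints: 13 27040, then 6 167040, then 4 295936
print("linear1 in_features:", 64 * size * size)  # 1024
```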

Using the layer implemented in Keras, I was able to get the network to perform. With this step I am recreating the paper "Learning in the Machine: To Share or not to Share?" without data augmentation (source code: the Learning-In-The-Machine/Weight-Sharing repository on GitHub). The reason I am trying to get this to work in PyTorch is that I consider it a learning experience, and my supervisors suggested it would be better for the overall project.
Also, when I use my training loop with a normal CNN, everything works fine. Therefore, my conclusion was that the issue must lie with the custom implementation of the local layer. Or am I missing something?

If anybody has suggestions, tips or help, it would be greatly appreciated! (Tips on how to debug my code or where to start are also welcome. I have already checked whether the data propagates correctly through the DNN, which it does as far as I can tell.) Also keep in mind that I am not too knowledgeable when it comes to PyTorch etc., so what might seem basic to you might be unknown to me :)

You could try to debug the issue by comparing the outputs of the custom PyTorch implementation with the Keras layer. Make sure to load the same parameters into both layers (starting with a plain nn.Linear module might be a good idea to check whether the memory layouts differ, etc.).
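Along the same lines, a check that needs only PyTorch: if the same filter is copied into every output position, a locally connected layer must reduce to an ordinary convolution. This sketch uses `F.conv2d` as the reference in place of the Keras layer, together with a hypothetical functional version, `locally_connected2d`, of the forward pass posted earlier in this thread. If the comparison fails, the weight layout or the unfold/flatten order is the first thing to inspect.

```python
import torch
import torch.nn.functional as F


def locally_connected2d(x, weight, bias, kernel_size, stride):
    # Hypothetical functional form of the LocallyConnected2d forward from this thread.
    # weight: [1, out_channels, in_channels, out_h, out_w, kh*kw]
    # bias:   [1, out_channels, out_h, out_w]
    kh, kw = kernel_size
    dh, dw = stride
    patches = x.unfold(2, kh, dh).unfold(3, kw, dw)  # [N, C, out_h, out_w, kh, kw]
    patches = patches.reshape(*patches.shape[:-2], kh * kw)
    out = (patches.unsqueeze(1) * weight).sum([2, -1])
    return out + bias


torch.manual_seed(0)
N, C, H, W = 2, 3, 10, 10
out_channels, k, s = 4, 3, 1
out_h = out_w = (H - k) // s + 1

# Shared convolution parameters
conv_w = torch.randn(out_channels, C, k, k)
conv_b = torch.randn(out_channels)

# Copy the shared filter into every output position. Flattening conv_w's (k, k)
# kernel row-major matches the (kh, kw) order produced by the two unfold calls.
lc_w = conv_w.reshape(1, out_channels, C, 1, 1, k * k).expand(1, out_channels, C, out_h, out_w, k * k)
lc_b = conv_b.view(1, out_channels, 1, 1).expand(1, out_channels, out_h, out_w)

x = torch.randn(N, C, H, W)
ref = F.conv2d(x, conv_w, conv_b, stride=s)
lc = locally_connected2d(x, lc_w, lc_b, (k, k), (s, s))
print(torch.allclose(ref, lc, atol=1e-5))  # True
```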

@ptrblck Are there any updates on this? The GitHub issue seems to be pretty dead, even though interest in this functionality remains high. New research in NeuroAI also shows the promise of these layers over simple convolutions in explaining certain properties observed in the visual system ([2308.09431] End-to-end topographic networks as models of cortical map formation and human visual behaviour: moving beyond convolutions). It would be great for all of us PyTorch users to have an official implementation that lets us play with these ideas.

Hi @Lasse, have you found the reason? Maybe it's overfitting on simple MNIST? Try some dropout?
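A minimal sketch of the dropout suggestion, assuming it is applied to the fully connected head of the network above (input `[N, 64, 4, 4]` from `LL3`); the dropout rate of 0.5 and the name `head` are assumptions. The head returns raw logits, which is what `nn.CrossEntropyLoss` expects.

```python
import torch
from torch import nn

# Fully connected head with dropout before each linear layer (p=0.5 is an arbitrary choice)
head = nn.Sequential(
    nn.Flatten(),
    nn.Dropout(p=0.5),
    nn.Linear(64 * 4 * 4, 512),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(512, 10),  # raw logits, no softmax
)

features = torch.randn(8, 64, 4, 4)  # stand-in for the LL3 output
head.train()  # dropout active during training
logits = head(features)
head.eval()  # dropout disabled for evaluation
eval_logits = head(features)
print(logits.shape)  # torch.Size([8, 10])
```

Dropout is only active in training mode, so calling `model.train()` and `model.eval()` around the training and evaluation loops matters.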

I would also like to see an official implementation of this, @ptrblck.