So I’m trying to implement a layer that performs weight sharing on a per-neuron basis.
I’ll give a basic example: we have 2 scalar parameters, a and b, 3-dimensional inputs and outputs, and we want the weight matrix to be given by:
```
a b b
b a b
b b a
```
One way I came up with is this:
```python
import torch
import torch.nn as nn

class WeightSharing(nn.Module):
    def __init__(self):
        super().__init__()
        # two shared scalars: index 0 -> a, index 1 -> b
        self.shared_weights = nn.Parameter(torch.randn(2))
        # fixed mapping from the 2 parameters to the 9 entries of W
        # (row-major), matching the a/b pattern above
        self.index_map = torch.LongTensor([0, 1, 1,
                                           1, 0, 1,
                                           1, 1, 0])

    def forward(self, x):
        # rebuild the 3x3 weight matrix from the shared parameters
        # on every forward pass, then apply it
        W = self.shared_weights[self.index_map].view(3, 3)
        return torch.matmul(x, W)
```
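For context, here's a quick toy usage, just to show the shapes I have in mind:

```python
layer = WeightSharing()
x = torch.randn(5, 3)  # batch of 5 three-dimensional inputs
y = layer(x)           # shape (5, 3)

# sanity check: the materialized weight matrix has the desired pattern
print(layer.shared_weights[layer.index_map].view(3, 3))
```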
The problem is that indexing the shared weights on every forward pass seems to hurt performance. Can anyone recommend a better approach? Is there a layer that can produce a view of the parameters based on a fixed index mapping?
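For what it's worth, one variant I've sketched (the class name is just for illustration) registers the index map as a buffer so it follows the module across devices, though I assume the per-forward indexing cost stays the same:

```python
class WeightSharingBuffered(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_weights = nn.Parameter(torch.randn(2))
        # register_buffer makes index_map follow .to(device) / .cuda(),
        # unlike the plain tensor attribute in the version above
        self.register_buffer("index_map",
                             torch.tensor([0, 1, 1,
                                           1, 0, 1,
                                           1, 1, 0]))

    def forward(self, x):
        W = self.shared_weights[self.index_map].view(3, 3)
        return torch.matmul(x, W)
```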