Layer for General Weight Sharing

So I’m trying to implement a layer that performs weight sharing on a per-neuron basis.

I’ll give a basic example: we have 2 scalar parameters, a and b, the inputs and outputs are 3-dimensional, and we want the weight matrix to be:

a b b
b a b
b b a
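
Just to make the target concrete, here is that matrix built directly from two example values (a = 0.5 and b = -1.0 are made up, not the trained values):

import torch

a, b = 0.5, -1.0                # arbitrary example values
W = torch.full((3, 3), b)       # fill every entry with b
W.fill_diagonal_(a)             # put a on the diagonal
# W is now [[a, b, b], [b, a, b], [b, b, a]]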

One way I came up with is this:

import torch
import torch.nn as nn

class WeightSharing(nn.Module):
    def __init__(self):
        super().__init__()
        # two scalar parameters: index 0 -> a, index 1 -> b
        self.shared_weights = nn.Parameter(torch.randn(2))
        # fixed map from each entry of the 3x3 weight matrix to a parameter;
        # registered as a buffer so it moves with the module to the right device
        index_map = torch.tensor([0, 1, 1,  1, 0, 1,  1, 1, 0])
        self.register_buffer("index_map", index_map)

    def forward(self, x):
        # gather the shared parameters into the full weight matrix
        W = self.shared_weights[self.index_map].view(3, 3)
        return torch.matmul(x, W)
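
For reference, this is how I use it and check that both parameters receive gradients (just a quick sanity check):

layer = WeightSharing()
x = torch.randn(4, 3)               # batch of 4 inputs
y = layer(x)                        # output shape (4, 3)
y.sum().backward()
print(layer.shared_weights.grad)    # both a and b get a gradient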

The problem is that indexing the shared weights on each forward pass seems to be detrimental to performance. Can anyone recommend a better approach to this? Is there a layer that can perform a view based on a fixed index mapping?
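
One alternative I've been considering (not sure whether it is actually faster, and the class name is just something I made up) is to precompute a fixed 0/1 mask per parameter and build W as a weighted sum of those masks, so the forward pass is a broadcasted multiply and a sum instead of a gather:

import torch
import torch.nn as nn

class MaskedWeightSharing(nn.Module):
    # sketch of the mask-based variant; MaskedWeightSharing is a hypothetical name
    def __init__(self):
        super().__init__()
        self.shared_weights = nn.Parameter(torch.randn(2))
        eye = torch.eye(3)
        # masks[k] marks the entries of the 3x3 matrix that use parameter k
        masks = torch.stack([eye, 1.0 - eye])   # masks[0] -> a, masks[1] -> b
        self.register_buffer("masks", masks)

    def forward(self, x):
        # W = a * masks[0] + b * masks[1]
        W = (self.shared_weights.view(2, 1, 1) * self.masks).sum(dim=0)
        return torch.matmul(x, W)

I don't know whether this actually beats the indexing version in practice, so I'd still appreciate any pointers to an existing layer or a more standard way of doing this.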