So I’m trying to implement a layer that performs weight sharing on a per-neuron basis.
I’ll give a basic example: we have 2 scalar parameters, a and b, our inputs and outputs are 3-dimensional, and we want the weight matrix to be given by:
a b b
b a b
b b a
One way I came up with is this:
```python
import torch
import torch.nn as nn

class WeightSharing(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_weights = nn.Parameter(torch.randn(2))
        # 0 selects a, 1 selects b: a on the diagonal, b elsewhere.
        # Registered as a buffer so it moves with the module (e.g. .to(device)).
        index_map = torch.tensor([0, 1, 1, 1, 0, 1, 1, 1, 0])
        self.register_buffer("index_map", index_map)

    def forward(self, x):
        W = self.shared_weights[self.index_map].view(3, 3)
        return torch.matmul(x, W)
```
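For clarity, this is the mapping I intend (0 selects a, 1 selects b, so a sits on the diagonal); a quick standalone check with dummy values:

```python
import torch

# Dummy values just to make the pattern visible: a = 1.0, b = 2.0
shared = torch.tensor([1.0, 2.0])
# Index 0 picks a, index 1 picks b: a on the diagonal, b off-diagonal
index_map = torch.tensor([0, 1, 1, 1, 0, 1, 1, 1, 0])
W = shared[index_map].view(3, 3)
print(W)
# tensor([[1., 2., 2.],
#         [2., 1., 2.],
#         [2., 2., 1.]])
```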
The problem is that indexing into the shared weights on every forward pass seems to hurt performance. Can anyone recommend a better approach? Is there a layer that can apply a view based on a fixed index mapping?
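For reference, one alternative I sketched replaces the fancy indexing with precomputed 0/1 masks, so the forward pass is just two scalar multiplies and an add (the class and mask names are my own; I haven’t benchmarked whether this is actually faster):

```python
import torch
import torch.nn as nn

class WeightSharingMasks(nn.Module):
    """Same layer, but W is assembled from fixed 0/1 masks instead of indexing."""
    def __init__(self):
        super().__init__()
        self.shared_weights = nn.Parameter(torch.randn(2))
        # One fixed mask per shared parameter; buffers move with .to(device)
        self.register_buffer("mask_a", torch.eye(3))          # 1s on the diagonal
        self.register_buffer("mask_b", 1.0 - torch.eye(3))    # 1s off the diagonal
    def forward(self, x):
        a, b = self.shared_weights[0], self.shared_weights[1]
        W = a * self.mask_a + b * self.mask_b  # a on the diagonal, b elsewhere
        return torch.matmul(x, W)
```

Gradients still flow to the two shared parameters, since W is built from them by differentiable ops.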