I would like to implement a layer with weight sharing that performs the matrix multiplication
M @ x, with
M = [[a, b], [c, a]]
where a, b and c are the unique weights.
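One possible approach is sketched below. It assumes PyTorch, since the question mentions nn.Linear. The idea is to store only the three unique values in one parameter vector and build M from it by indexing. The class name SharedWeightLinear and the index layout are illustrative choices, not an existing PyTorch API. Autograd sums the gradients of both a entries into the same parameter, so the sharing is exact.

```python
import torch
import torch.nn as nn


class SharedWeightLinear(nn.Module):
    """Hypothetical 2x2 linear map with M = [[a, b], [c, a]] built from 3 parameters."""

    def __init__(self):
        super().__init__()
        # Only the unique weights a, b, c are stored as trainable parameters.
        self.w = nn.Parameter(torch.randn(3))
        # idx[i, j] says which unique weight sits at M[i, j]: 0 -> a, 1 -> b, 2 -> c.
        # A buffer moves with .to(device) but is not a trainable parameter.
        self.register_buffer("idx", torch.tensor([[0, 1], [2, 0]]))

    def matrix(self):
        # Advanced indexing gathers the shared values into a 2x2 matrix;
        # in backward, the gradients of both 'a' entries accumulate into w[0].
        return self.w[self.idx]

    def forward(self, x):
        # x has shape (..., 2); this computes M @ x for every vector in the batch.
        return x @ self.matrix().T


layer = SharedWeightLinear()
print(sum(p.numel() for p in layer.parameters()))  # 3
```

With this design the stored model holds 3 weights instead of 4, but the forward pass still performs a full dense 2x2 product.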
How can I implement this layer?
I tried instantiating M as an nn.Linear module and then doing M.weight = M.weight. While this achieves weight sharing, the forward pass costs the same as in a layer without weight sharing, and the model size is the same as well.
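If the goal is also to avoid materializing M and calling a dense matmul, a further hypothetical sketch (the class name SharedWeight2x2 is made up) writes both output rows explicitly from three scalar parameters. Note that for a 2x2 matrix this still takes four multiplications, so sharing mainly reduces the stored parameters rather than the arithmetic. Real compute savings depend on the structure: M here is a Toeplitz matrix, and for large Toeplitz matrices the matrix-vector product can be computed with FFTs in O(n log n) instead of O(n^2).

```python
import torch
import torch.nn as nn


class SharedWeight2x2(nn.Module):
    """Hypothetical layer computing y = M @ x, M = [[a, b], [c, a]], without building M."""

    def __init__(self):
        super().__init__()
        # Three scalar (0-dim) parameters; 'a' is used in both rows.
        self.a = nn.Parameter(torch.randn(()))
        self.b = nn.Parameter(torch.randn(()))
        self.c = nn.Parameter(torch.randn(()))

    def forward(self, x):
        # x has shape (..., 2); split it into its two components.
        x0, x1 = x[..., 0], x[..., 1]
        y0 = self.a * x0 + self.b * x1  # first row:  a*x0 + b*x1
        y1 = self.c * x0 + self.a * x1  # second row: c*x0 + a*x1
        return torch.stack((y0, y1), dim=-1)
```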
Any help would be greatly appreciated, thanks.