I would like to implement a layer with weight sharing that performs the matrix multiplication `M @ x`, with

M = [
  [a, b],
  [c, a]
]

where `a`, `b`, and `c` are the unique weights.
How can I implement this layer?
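One way to sketch this, assuming a batch of inputs `x` with shape `(batch, 2)`, is a custom module that registers only the three unique weights as parameters and rebuilds `M` from them on every forward pass. Because the diagonal entries both come from the same `a` parameter, autograd accumulates the gradients from both positions into it. The class name `SharedWeightLayer` and the standard-normal initialization are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn


class SharedWeightLayer(nn.Module):
    """Computes M @ x with M = [[a, b], [c, a]] using only three learnable weights."""

    def __init__(self):
        super().__init__()
        # One scalar parameter per unique weight (hypothetical init: standard normal).
        self.a = nn.Parameter(torch.randn(()))
        self.b = nn.Parameter(torch.randn(()))
        self.c = nn.Parameter(torch.randn(()))

    def weight_matrix(self) -> torch.Tensor:
        # Rebuild M from the shared parameters each call; both diagonal
        # entries reference self.a, so their gradients sum into it.
        row0 = torch.stack([self.a, self.b])
        row1 = torch.stack([self.c, self.a])
        return torch.stack([row0, row1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, 2); multiplying each row by M is x @ M.T.
        return x @ self.weight_matrix().T
```

With `a = 1`, `b = 2`, `c = 3` and input `[1, 1]`, the layer returns `[3, 4]`, and `parameters()` yields three scalars instead of a full 2×2 matrix.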
I tried instantiating `M` as an `nn.Linear` module and then setting `M.weight[1][1] = M.weight[0][0]`. While this works for weight sharing, the forward pass still costs as much as a layer without weight sharing, and the model size is unchanged because the full weight matrix is still stored.
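To separate the two concerns in the question, here is a sketch (with a hypothetical class name `SharedWeightLayerDirect`) that stores the unique weights as one vector `w = [a, b, c]` and writes out the product without materializing `M`. Written this way, the output still needs four multiplications, the same as a dense 2×2 product, so the direct saving from sharing is in the number of stored parameters rather than in the arithmetic. The comparison uses `nn.Linear(2, 2, bias=False)` so that only the weight matrix is counted.

```python
import torch
import torch.nn as nn


class SharedWeightLayerDirect(nn.Module):
    """Computes M @ x with M = [[a, b], [c, a]] without building M explicitly."""

    def __init__(self):
        super().__init__()
        # The three unique weights in one vector: w = [a, b, c].
        self.w = nn.Parameter(torch.randn(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b, c = self.w  # iterating along dim 0 gives three 0-dim tensors
        x0, x1 = x[..., 0], x[..., 1]
        # y0 = a*x0 + b*x1, y1 = c*x0 + a*x1: still four multiplications.
        return torch.stack([a * x0 + b * x1, c * x0 + a * x1], dim=-1)


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


dense = nn.Linear(2, 2, bias=False)  # stores all four entries of M
shared = SharedWeightLayerDirect()   # stores only a, b, c
print(count_params(dense), count_params(shared))  # 4 3
```

For a layer this small the difference is one parameter; the same pattern (store the unique values, build or apply the structured matrix from them) is what matters when the shared structure is larger.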
Any help would be greatly appreciated, thanks.