Efficient Weight Sharing

I would like to implement a layer with weight sharing that performs the matrix multiplication M @ x, with

M = [
    [a, b],
    [c, a],
]

where a, b, and c are the unique weights.

How can I implement this layer?
I tried instantiating M as an nn.Linear module, then doing M.weight[1][1] = M.weight[0][0]. While this works for weight sharing, the forward-pass cost and the model size are still the same as for a layer without weight sharing.
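
For illustration, here is a minimal sketch of one way such a layer could be written. It assumes PyTorch, stores only the three unique weights as parameters, and assembles M inside forward. The class name SharedWeight2x2 and the matrix() helper are hypothetical names chosen for this example, not part of any library. Note that assigning one entry of weight to another only copies the value once, whereas here both diagonal positions read from the same parameter a, so autograd accumulates the gradient from both.

```python
import torch
import torch.nn as nn


class SharedWeight2x2(nn.Module):
    """Hypothetical layer computing M @ x with M = [[a, b], [c, a]]."""

    def __init__(self):
        super().__init__()
        # Only the three unique weights are stored as parameters.
        self.a = nn.Parameter(torch.randn(()))
        self.b = nn.Parameter(torch.randn(()))
        self.c = nn.Parameter(torch.randn(()))

    def matrix(self):
        # Assemble M from the parameters; `a` is used in both diagonal
        # positions, so its gradient is the sum over both positions.
        row0 = torch.stack([self.a, self.b])
        row1 = torch.stack([self.c, self.a])
        return torch.stack([row0, row1])

    def forward(self, x):
        # x has shape (..., 2); x @ M.T applies M @ x to each vector.
        return x @ self.matrix().T


layer = SharedWeight2x2()
x = torch.randn(5, 2)
y = layer(x)
num_params = sum(p.numel() for p in layer.parameters())
```

With this layout the model holds three parameters instead of four. The forward pass still performs a dense 2x2 product (four multiplications per input vector), so this sketch addresses model size rather than arithmetic cost.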

Any help would be greatly appreciated, thanks.