Force a linear layer's weights to be a diagonal matrix

I have seen similar questions, but none with a well-defined solution: I want to force a linear layer's weight matrix to be diagonal. Is there a way to do this?

One workaround could be to simply set the off-diagonal entries to zero before each forward pass through the network. Would this work? I'm not sure whether it would cause issues with backprop.
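For reference, here is a minimal sketch of what that masking workaround might look like; the shapes and the in-place masking pattern are my own assumptions, not something prescribed by PyTorch:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4, bias=False)
mask = torch.eye(4)  # ones on the diagonal, zeros elsewhere

# Zero the off-diagonal entries in place before the forward pass.
with torch.no_grad():
    layer.weight.mul_(mask)

x = torch.randn(2, 4)
out = layer(x)
out.sum().backward()

# Gradients are still computed for every entry of the full weight matrix,
# so the off-diagonals can drift away from zero after an optimizer step
# unless they are re-zeroed (or their gradients masked) every iteration.
print(layer.weight.grad.shape)  # torch.Size([4, 4])
```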

I ended up setting up that linear layer to hold a vector of weights rather than a full matrix. Then, when it comes time to apply it, I simply call torch.diag_embed(vector) (which is differentiable) to build the diagonal matrix. This way the weight matrix is only ever diagonal.
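A minimal sketch of that approach, assuming a custom module (the name DiagonalLinear and the bias handling are illustrative choices, not part of any PyTorch API):

```python
import torch
import torch.nn as nn

class DiagonalLinear(nn.Module):
    """Linear layer whose weight matrix is diagonal by construction."""

    def __init__(self, features: int, bias: bool = True):
        super().__init__()
        # Only the diagonal entries are learnable parameters.
        self.weight = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # diag_embed builds the full diagonal matrix and is differentiable,
        # so gradients flow back into the weight vector.
        W = torch.diag_embed(self.weight)
        out = x @ W
        if self.bias is not None:
            out = out + self.bias
        return out

# Usage
layer = DiagonalLinear(4)
x = torch.randn(2, 4)
y = layer(x)
y.sum().backward()
print(layer.weight.grad)  # gradients exist only for the diagonal entries
```

Since multiplying by a diagonal matrix is the same as an element-wise scaling, `x * self.weight` would also work and avoids materializing the full matrix, but the diag_embed version makes the "diagonal weight matrix" structure explicit.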