I’m trying to force a linear layer to only have non-zero weights along the diagonal. Is there a straightforward way to do this?

Assuming that `x` is your input, and that your output `y` (and thus the resulting loss) is a function of `x.W + b`, you can force `W` to be diagonal by adding a constraint to your loss function: `normal_loss + sum({norm(W_ij), i != j})`, i.e. regularization on all elements of `W` except those on the diagonal.
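
A minimal sketch of what that penalty could look like in PyTorch (the layer size, the choice of an L1 norm, and the weight `lam` are my own illustrative assumptions, not part of the answer above):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)          # square layer so a diagonal W makes sense
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
lam = 1e-2                        # strength of the off-diagonal penalty (assumed)

x = torch.randn(32, 8)
target = torch.randn(32, 8)

y = model(x)
normal_loss = criterion(y, target)

# Zero out the diagonal of W, then penalize whatever remains (the off-diagonal part).
W = model.weight
off_diag = W - torch.diag(torch.diagonal(W))
loss = normal_loss + lam * off_diag.abs().sum()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that this is a soft penalty: the off-diagonal weights are pushed toward zero but not forced to be exactly zero, so you may still want to mask or zero them out afterwards if you need a strictly diagonal matrix.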

Alternatively, roll your own module and change the computations as you need; for example, you can build the weight matrix as `torch.diag(F.softplus(raw_W_diag_param))` (and add an inverse Cholesky decomposition if you want off-diagonal terms).
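
A rough sketch of that idea, assuming a hand-rolled module (the class and parameter names here are hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiagonalLinear(nn.Module):
    """Linear layer whose weight matrix is diagonal by construction."""

    def __init__(self, features):
        super().__init__()
        # Raw parameter for the diagonal; softplus keeps the effective
        # diagonal entries strictly positive.
        self.raw_W_diag_param = nn.Parameter(torch.zeros(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        # Off-diagonal entries are exactly zero and never trained.
        W = torch.diag(F.softplus(self.raw_W_diag_param))
        return x @ W + self.bias

layer = DiagonalLinear(8)
out = layer(torch.randn(32, 8))   # gradients flow only to the diagonal entries
```

Unlike the penalty approach above, the off-diagonal entries here are never parameters at all, so they stay exactly zero by construction.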