I’m trying to force a linear layer to only have non-zero weights along the diagonal. Is there a straightforward way to do this?
Assuming that `x` is your input, and that your output `y` (and thus the resulting loss) is a function of `x.W + b`, you can push `W` toward being diagonal by adding a penalty term to your loss: `normal_loss + sum(norm(W_ij) for i != j)`, i.e. regularization on all elements of `W` except those on the diagonal. Note that this is a soft penalty: off-diagonal weights are driven toward zero, but not forced to be exactly zero.
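A minimal sketch of that penalty in PyTorch, assuming a square `nn.Linear` layer and an MSE task loss (the layer sizes, `lam`, and the helper name `off_diagonal_penalty` are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

def off_diagonal_penalty(W: torch.Tensor) -> torch.Tensor:
    """Sum of absolute values of all off-diagonal entries of a square matrix."""
    off_diag = W - torch.diag(torch.diag(W))  # zero out the diagonal
    return off_diag.abs().sum()

linear = nn.Linear(4, 4)            # square layer, so W is 4x4
x = torch.randn(8, 4)               # dummy batch
target = torch.randn(8, 4)

y = linear(x)
base_loss = nn.functional.mse_loss(y, target)

lam = 1e-2                          # regularization strength (tunable)
loss = base_loss + lam * off_diagonal_penalty(linear.weight)
loss.backward()                     # gradients now also shrink off-diagonal weights
```

With each optimizer step, the off-diagonal entries of `linear.weight` are pulled toward zero; increasing `lam` makes the matrix closer to diagonal at the cost of task loss.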
Alternatively, roll your own module and change the computation as you need. For example, you can build the weight as `torch.diag(F.softplus(raw_W_diag_param))`, which makes the matrix exactly diagonal (with positive entries, thanks to the softplus); compose it with an inverse Cholesky factor if you want off-diagonals.
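A sketch of such a custom module, assuming you want a strictly positive diagonal via softplus (the class name `DiagonalLinear` and its interface are hypothetical; drop the softplus if you need unconstrained signs):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiagonalLinear(nn.Module):
    """Linear layer whose weight matrix is exactly diagonal.

    The diagonal is parameterized through softplus, so the effective
    weights are strictly positive. Hypothetical sketch, not a
    published API.
    """
    def __init__(self, features: int, bias: bool = True):
        super().__init__()
        self.raw_diag = nn.Parameter(torch.zeros(features))
        self.bias = nn.Parameter(torch.zeros(features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Materialize the diagonal weight matrix; off-diagonals are exactly 0.
        W = torch.diag(F.softplus(self.raw_diag))
        out = x @ W
        if self.bias is not None:
            out = out + self.bias
        return out
```

Since the matrix is diagonal, `x @ W` is equivalent to the cheaper elementwise product `x * F.softplus(self.raw_diag)`; materializing `W` with `torch.diag` just makes the structure explicit.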