I’m trying to train a model that uses a symmetric matrix for its linear layer, and I wonder how to implement the symmetric matrix efficiently in PyTorch.
I have seen this approach, but I think it does not fulfill my needs, since it introduces more trainable parameters than necessary: `features * features` parameters instead of `features * (features + 1) / 2`.
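To make the difference concrete, here is the parameter count for the small example I use below (`features = 3`):

```python
features = 3
dense = features * features            # parameters in a full weight matrix
tri = features * (features + 1) // 2   # parameters in the lower triangle, diagonal included
print(dense, tri)  # 9 6
```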
I came up with the following solution:
```python
import torch

features = 3
num_weights = features * (features + 1) // 2
weights = torch.randn(num_weights)

tri_mat = torch.zeros(features, features)
tri_idx = torch.tril_indices(features, features)
tri_mat[tri_idx[0], tri_idx[1]] = weights  # fill the lower triangle

symmetric_weight = torch.tril(tri_mat) + torch.tril(tri_mat, -1).t()
```
However, I suspect creating the temporary `tri_mat` could be avoided. Any ideas on how to do this more elegantly/efficiently?
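For context, this is roughly how I intend to use it in practice (a minimal sketch; the class name `SymmetricLinear` and the bias-free matmul are just my placeholder choices, not a fixed requirement):

```python
import torch
import torch.nn as nn

class SymmetricLinear(nn.Module):
    """Linear layer whose weight matrix is constrained to be symmetric.

    Only the n * (n + 1) / 2 lower-triangular entries are trainable.
    """

    def __init__(self, features: int):
        super().__init__()
        self.features = features
        num_weights = features * (features + 1) // 2
        self.weights = nn.Parameter(torch.randn(num_weights))
        # Precompute the lower-triangular index pairs once (not trainable).
        self.register_buffer("tri_idx", torch.tril_indices(features, features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scatter the flat parameter vector into the lower triangle.
        tri_mat = torch.zeros(
            self.features, self.features, device=x.device, dtype=x.dtype
        )
        tri_mat[self.tri_idx[0], self.tri_idx[1]] = self.weights
        # Mirror the strictly-lower part onto the upper part.
        symmetric_weight = tri_mat + torch.tril(tri_mat, -1).t()
        return x @ symmetric_weight
```

Passing the identity through the layer returns the weight matrix itself, which is an easy way to check that it really is symmetric and that only 6 parameters are trained for `features = 3`.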