Properly initializing triangular weight matrices

I’m using a class that implements a linear layer with a lower-triangular weight matrix, defined here:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TriLinear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        # zero out the strict upper triangle of the initial weight
        with torch.no_grad():
            self.weight.copy_(torch.tril(self.weight))
        # mask gradients so the upper triangle stays zero during training
        self.weight.register_hook(lambda grad: grad * torch.tril(torch.ones_like(grad)))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

I have to replace a few regular linear layers of an existing model with this class and compare the performance. However, the larger model's initialization contains the following code:

for p in self.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

It seems that when this runs, the weight matrices of my TriLinear class are no longer triangular. Is there a way to implement this initialization without interfering with the triangular weight matrices?

Hi Big!

The most satisfactory approach would be to modify the Xavier
initialization to iterate over modules and skip your custom layers:

for m in model.modules():
    if isinstance(m, TriLinear):
        continue
    # Xavier-initialize the parameters in m
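Concretely, the loop above could be fleshed out as follows. This is a minimal sketch: the `model` here is a hypothetical stand-in, and the `TriLinear` copy below drops the gradient hook for brevity. Note the `recurse=False`, which keeps container modules (such as `nn.Sequential`) from re-initializing their children's parameters:

```python
import torch
import torch.nn as nn

# stand-in TriLinear for illustration (gradient hook omitted for brevity)
class TriLinear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        with torch.no_grad():
            self.weight.copy_(torch.tril(self.weight))

# hypothetical model mixing regular and triangular linear layers
model = nn.Sequential(nn.Linear(4, 4), TriLinear(4, 4))

for m in model.modules():
    if isinstance(m, TriLinear):
        continue  # leave the triangular weights alone
    for p in m.parameters(recurse=False):
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

# the TriLinear weight is still lower-triangular
assert torch.equal(model[1].weight, model[1].weight.tril())
```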

However, for a simple tweak to your existing code that is exceedingly
unlikely to fail, try:

for p in self.parameters():
    if p.dim() == 2 and torch.equal(p, p.tril()):
        continue  # already lower-triangular: leave it alone
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
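To illustrate why this works, here is a small self-contained check (the parameter names are hypothetical). The triangular check could in principle misfire if a regular weight matrix happened to be exactly lower-triangular, but for randomly initialized matrices that is, as noted, exceedingly unlikely:

```python
import torch
import torch.nn as nn

w_tri = nn.Parameter(torch.tril(torch.randn(4, 4)))  # triangular weight
w_full = nn.Parameter(torch.randn(4, 4))             # ordinary weight
before = w_tri.detach().clone()

for p in (w_tri, w_full):
    if p.dim() == 2 and torch.equal(p, p.tril()):
        continue  # skip parameters that are already lower-triangular
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# the triangular parameter was left untouched
assert torch.equal(w_tri, before)
```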

Best.

K. Frank


This is perfect, thank you.