I’m using a class that implements a linear layer with a lower-triangular weight matrix, defined here:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TriLinear(nn.Linear):
        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features, bias=False)
            # Zero out the upper triangle of the default-initialized weight
            with torch.no_grad():
                self.weight.copy_(torch.tril(self.weight))
            # Mask gradients so the upper triangle stays zero during training
            self.weight.register_hook(lambda grad: grad * torch.tril(torch.ones_like(grad)))

        def forward(self, x):
            return F.linear(x, self.weight, self.bias)
I need to replace a few of the regular linear layers in an existing model with this class and compare the performance. However, the larger model's initialization contains the following code:
    for p in self.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
When this loop runs, it applies `nn.init.xavier_uniform_` in place to every parameter with more than one dimension, which includes the weights of my TriLinear layers, so they are no longer triangular afterwards. Is there a way to implement this initialization without interfering with the triangular weight matrices?
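For reference, here is a minimal, self-contained reproduction of the symptom (the class restated together with the init loop, applied to a single layer instead of a whole model):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TriLinear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        # Zero out the upper triangle of the default-initialized weight
        with torch.no_grad():
            self.weight.copy_(torch.tril(self.weight))
        # Mask gradients so the upper triangle stays zero during training
        self.weight.register_hook(lambda grad: grad * torch.tril(torch.ones_like(grad)))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

torch.manual_seed(0)
layer = TriLinear(4, 4)
# Triangular right after construction
print(torch.equal(layer.weight, torch.tril(layer.weight)))  # True

# The model-wide init loop rewrites every >1-D parameter in place,
# including layer.weight, refilling its upper triangle
for p in layer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# No longer triangular
print(torch.equal(layer.weight, torch.tril(layer.weight)))  # False
```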