I’m using a class that implements a linear layer with a lower-triangular weight matrix, defined as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TriLinear(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)
        # Zero the upper triangle of the randomly initialized weight
        with torch.no_grad():
            self.weight.copy_(torch.tril(self.weight))
        # Mask gradients so the upper triangle stays zero during training
        self.weight.register_hook(lambda grad: grad * torch.tril(torch.ones_like(grad)))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)
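As a sanity check, the weight is lower-triangular right after construction (a 4x4 layer here is just a made-up size for illustration):

layer = TriLinear(4, 4)
# Prints True: the upper triangle was zeroed in __init__
print(torch.equal(layer.weight, torch.tril(layer.weight)))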
I need to replace a few of the regular linear layers in an existing model with this class and compare performance. However, the larger model's initialization contains the following code:
for p in self.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
When this runs, the weight matrices of my TriLinear layers are no longer triangular: xavier_uniform_ refills the whole parameter in place, overwriting the zeros in the upper triangle (the gradient hook only masks gradients, so it doesn't prevent this). Is there a way to apply this initialization without destroying the triangular structure?
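For completeness, here is a minimal standalone reproduction of what I'm seeing (again with a made-up 4x4 layer):

layer = TriLinear(4, 4)
print(torch.equal(layer.weight, torch.tril(layer.weight)))  # True

# Same loop as in the model's __init__; xavier_uniform_ overwrites the
# zeros in the upper triangle along with everything else.
for p in layer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

print(torch.equal(layer.weight, torch.tril(layer.weight)))  # almost surely False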