Weight freezing

Hello everyone,

I am new to PyTorch and have a question about weight freezing. Suppose I have a single linear layer and I want to initialize its weight matrix as a lower triangular matrix. During training, I would like only the lower triangular part to be updated while the upper triangular part remains zero. I have created a lower triangular mask of the same size as the matrix and have tried using detach() and requires_grad = False for the indices where the mask is 0, like so:
model.weight.data[mask == 0].requires_grad = False

model.weight.data[mask == 0] = model.weight.data[mask == 0].detach()

However, both attempts freeze the entire layer, so it does not update. I have also tried manually passing the elements at these indices to the optimizer, but that doesn't seem to work either.
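A small check of why neither line can freeze individual elements. It assumes a 4x4 nn.Linear layer and a tril mask, both chosen only for illustration. requires_grad is a flag on a whole tensor, and boolean-mask indexing returns a new tensor containing copies, not a view. So in isolation, neither line changes the parameter or its gradient behaviour.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
mask = torch.tril(torch.ones(4, 4))
before = layer.weight.detach().clone()

# Attempt 1: boolean-mask indexing returns a NEW tensor holding copies of the
# selected elements, not a view into layer.weight, so changing its flag or
# its values leaves the parameter alone.
upper = layer.weight.data[mask == 0]
upper.requires_grad = False
upper.zero_()
unchanged_after_first = torch.equal(layer.weight.detach(), before)

# Attempt 2: reads the same values and writes them straight back (a no-op).
layer.weight.data[mask == 0] = layer.weight.data[mask == 0].detach()
unchanged_after_second = torch.equal(layer.weight.detach(), before)

# requires_grad belongs to the whole parameter tensor, not to single elements.
print(layer.weight.requires_grad)  # True
print(unchanged_after_first, unchanged_after_second)  # True True
```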

Any help would be greatly appreciated. Thanks in advance!

You could multiply the mask with the weight directly in the forward pass.
Because the mask is a constant, autograd treats the multiplication like scaling by a constant: the gradient reaching each weight entry is multiplied by the corresponding mask entry, so it is 0 wherever mask == 0.
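The zero-gradient claim follows from the chain rule. Writing $W$ for the weight, $M$ for the 0/1 mask, $\tilde W = W \odot M$ for the masked weight used in the forward pass, and $L$ for the loss (notation introduced here for the sketch):

```latex
\tilde W_{kl} = W_{kl} M_{kl}
\quad\Rightarrow\quad
\frac{\partial \tilde W_{kl}}{\partial W_{ij}} = M_{kl}\,\delta_{ik}\,\delta_{jl}
\quad\Rightarrow\quad
\frac{\partial L}{\partial W_{ij}}
= \sum_{k,l} \frac{\partial L}{\partial \tilde W_{kl}}\,\frac{\partial \tilde W_{kl}}{\partial W_{ij}}
= M_{ij}\,\frac{\partial L}{\partial \tilde W_{ij}}
```

Every entry with $M_{ij} = 0$ therefore receives exactly zero gradient, regardless of the rest of the network.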

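import torch
import torch.nn as nn

class MyClass(nn.Module):
    def __init__(self) -> None:
        super().__init__()  # required before assigning parameters or buffers
        self.weight = nn.Parameter(torch.randn(5, 5))
        # Constant 0/1 mask (upper triangular here). As a buffer it follows .to(device)
        # but is not returned by parameters(), so the optimizer never sees it.
        self.register_buffer('mask', torch.triu(torch.ones_like(self.weight)))

    def forward(self, x):
        # Masked entries contribute nothing to the output and get zero gradient
        out = (self.weight * self.mask).matmul(x)
        return out

if __name__ == '__main__':
    data = torch.randn(5, 4)
    model = MyClass()
    out = model(data)
    out.sum().backward()  # reduce the (5, 4) output to a scalar so backward() works
    print(model.weight.grad)  # zero below the diagonal, where mask == 0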
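A minimal sketch of the same masking idea inside a training loop, adapted to the question's lower triangular case. The specific choices are illustrative assumptions: an nn.Linear layer without bias, torch.tril instead of torch.triu, plain SGD, and random toy data. The sketch initializes the weight as lower triangular and checks that the upper entries are still zero after several optimizer steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed setup: 5x5 linear layer, lower triangular mask as in the question.
layer = nn.Linear(5, 5, bias=False)
mask = torch.tril(torch.ones(5, 5))

# Initialize the weight as lower triangular (upper part set to zero).
with torch.no_grad():
    layer.weight.mul_(mask)
initial_lower = layer.weight.detach()[mask == 1].clone()

optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(8, 5)
target = torch.randn(8, 5)

for _ in range(10):
    optimizer.zero_grad()
    # F.linear computes x @ W.T; the mask is applied on every forward pass.
    out = nn.functional.linear(x, layer.weight * mask)
    loss = ((out - target) ** 2).mean()
    loss.backward()
    optimizer.step()

upper = layer.weight.detach()[mask == 0]
print(torch.count_nonzero(upper).item())  # 0
print(torch.equal(layer.weight.detach()[mask == 1], initial_lower))  # False: lower part trained
```

Because the masked entries receive exactly zero gradient, plain SGD leaves them at their initial value, so zero-initialized entries stay zero. Even if some optimizer did move them, the mask in the forward pass would keep them from affecting the output.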

This is a really great and simple solution, thank you very much. As I am new to PyTorch, I still don't quite understand why it works; specifically, why do you call out.sum() and then backward() on it?
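A sketch of what out.sum() is presumably there for. It assumes the answer's shapes (5x5 weight, 5x4 input), with a plain tensor standing in for the module. Calling backward() without arguments only works on a single-element output. Summing first is equivalent to passing a tensor of ones as the upstream gradient.

```python
import torch

w = torch.randn(5, 5, requires_grad=True)
mask = torch.triu(torch.ones(5, 5))
x = torch.randn(5, 4)

out = (w * mask).matmul(x)  # shape (5, 4), not a scalar

# backward() with no gradient argument needs a scalar output.
try:
    out.backward()
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# Summing yields a scalar whose gradient w.r.t. out is all ones.
out.sum().backward()
grad_from_sum = w.grad.clone()

# Equivalent: supply the all-ones upstream gradient explicitly.
w.grad = None
out2 = (w * mask).matmul(x)
out2.backward(torch.ones_like(out2))
print(torch.allclose(grad_from_sum, w.grad))  # True
print(torch.all(w.grad[mask == 0] == 0).item())  # True
```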
I presume that if I use this model in a training loop with an optimizer and call optimizer.step(), it will only change the weights in the upper triangular part, since the mask you created is an upper triangular matrix.
Just out of curiosity, would passing only a slice of the weight matrix to the optimizer update only that slice? Thanks again.
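On the slicing question, a quick check using a hypothetical 5x5 tensor w and plain SGD. A slice of a tensor that requires grad is the output of an indexing operation, so it is not a leaf, and torch.optim rejects non-leaf tensors.

```python
import torch

w = torch.randn(5, 5, requires_grad=True)

# Slicing a leaf that requires grad produces a non-leaf tensor (it has a grad_fn).
part = w[:, :2]
print(part.is_leaf)  # False

# torch.optim only accepts leaf tensors as parameters.
try:
    torch.optim.SGD([part], lr=0.1)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```

A slice taken from w.detach() or w.data would be accepted as a leaf. However, it is not part of the computation graph, so it never receives a .grad and SGD's step() skips it, meaning it would not be updated either.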