I am new to PyTorch and have a question about weight freezing. Suppose I have a single linear layer and want to initialize its weight matrix as a lower triangular matrix. During training, only the lower triangular part should be updated, and the upper triangular part should remain zero. I have created a lower triangular mask of the same size as the matrix and tried using detach and requires_grad = False at the indices where the mask is 0, like so:
model.weight.data[mask == 0].requires_grad = False
You could directly multiply the weight by the mask.
Since the mask acts as a constant in the multiplication, it is factored into the gradient as well, so the gradient will be 0 at the positions where mask = 0.
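import torch
import torch.nn as nn

class MyClass(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(5, 5))
        # triu keeps the upper triangle; use torch.tril for a lower triangular mask
        self.mask = torch.triu(torch.ones_like(self.weight))

    def forward(self, x):
        out = (self.weight * self.mask).matmul(x)
        return out

if __name__ == '__main__':
    data = torch.randn(5, 4)
    model = MyClass()
    out = model(data)
    out.sum().backward()      # reduce to a scalar, then backpropagate
    print(model.weight.grad)  # zero wherever mask == 0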
This is a really great and simple solution, thank you very much. As I am new to PyTorch, I still don't understand why this works, specifically why you call out.sum() and then backward() on it.
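On the out.sum() part: backward() called without arguments only works on a scalar, so summing the output is just a convenient way to get one for a demo. In real training the loss plays that role. Here is a minimal sketch, assuming the same 5x5 weight and upper triangular mask as in the code above. The standalone variable names are only for illustration.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(5, 5, requires_grad=True)
mask = torch.triu(torch.ones(5, 5))
x = torch.randn(5, 4)

out = (weight * mask).matmul(x)  # shape (5, 4), not a scalar

# backward() with no arguments needs a scalar, so reduce first.
out.sum().backward()
grad_from_sum = weight.grad.clone()

# Equivalent: keep the output as is and pass an upstream gradient of ones.
weight.grad = None
out2 = (weight * mask).matmul(x)
out2.backward(torch.ones_like(out2))
grad_from_ones = weight.grad.clone()

same = torch.allclose(grad_from_sum, grad_from_ones)
masked_grad_is_zero = bool((grad_from_sum[mask == 0] == 0).all())
print(same, masked_grad_is_zero)
```

Both calls give the same gradient, and in both cases it is exactly zero wherever the mask is 0, because the chain rule multiplies the incoming gradient by the mask.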
I presume that if I use this model in a training loop with an optimizer and call optimizer.step(), it would only change the weights of the upper triangular part, as the mask created is an upper triangular matrix.
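A quick way to check this is to run a few optimizer steps and compare the weights before and after. The sketch below assumes plain SGD with no momentum and no weight decay; MaskedLinear, the loss, and the learning rate are made up for illustration.

```python
import torch
import torch.nn as nn

class MaskedLinear(nn.Module):  # hypothetical name, mirrors the module above
    def __init__(self, n: int = 5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n, n))
        # A buffer follows the module across .to(device) but is not trained.
        self.register_buffer("mask", torch.triu(torch.ones(n, n)))

    def forward(self, x):
        return (self.weight * self.mask).matmul(x)

torch.manual_seed(0)
model = MaskedLinear()
before = model.weight.detach().clone()

# Assumption: plain SGD, so a zero gradient means a zero update.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):
    opt.zero_grad()
    loss = model(torch.randn(5, 4)).pow(2).mean()  # dummy loss
    loss.backward()
    opt.step()

after = model.weight.detach()
mask = model.mask
frozen_unchanged = torch.equal(after[mask == 0], before[mask == 0])
active_changed = not torch.equal(after[mask == 1], before[mask == 1])
print(frozen_unchanged, active_changed)
```

Two caveats. The masked entries of the raw weight keep their random initial values rather than being zero; it is the product weight * mask that is triangular, so use that product if you need the effective matrix. Also, an optimizer with weight decay would still shrink the masked entries, since weight decay is applied even when the gradient is zero.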
Just out of curiosity, would providing only a slice of the weight matrix to the optimizer update only that slice? Thanks again.
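On the slice question: as far as I know this does not work directly, because indexing a parameter produces a new non-leaf tensor and the optimizers in torch.optim only accept leaf tensors. A minimal sketch, with illustrative variable names:

```python
import torch

weight = torch.nn.Parameter(torch.randn(5, 5))
sliced = weight[:2]  # indexing creates a new tensor with a grad_fn

print(weight.is_leaf, sliced.is_leaf)

try:
    torch.optim.SGD([sliced], lr=0.1)
    error = None
except ValueError as exc:
    # The optimizer refuses tensors that are not leaves of the graph.
    error = str(exc)

print(error)
```

So the options are to keep the whole matrix as one parameter and mask it as shown earlier, or to split the matrix into separate nn.Parameter objects and pass only the ones you want trained to the optimizer.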