How to learn only a subset of weights in gradient descent keeping the rest constant

I have implemented a custom loss function and am learning the weights using PyTorch's autograd system. In the weight matrix w, of dimension 16*32, I want to learn only a subset of the weights: for instance w[0][0] and w[0][1] in the first row, w[1][2] and w[1][3] in the second row, and so on. Every row has a fixed set of two weights that I want to learn; the remaining elements should be set to zero and not learned. I have read about the requires_grad flag, but I have also read in this post that it cannot be set on a subset of a tensor's elements.

  1. After learning the weights using gradient descent, is it useful to extract only the required weights and set all the remaining weights to zero?
  2. Is setting the unneeded weights to 0 after every iteration of gradient descent a possible solution?
    Can you please suggest a way to achieve this?

@tom explains a nice way to achieve this in this post and also mentions the shortcomings of other approaches such as zeroing out the gradients of some elements.
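To illustrate one shortcoming of gradient zeroing (the linked post may list others), here is a minimal sketch. It assumes a small 4 x 8 weight filled with ones, a toy loss, and plain SGD with weight decay; all of these are illustrative choices, not taken from the thread. The hook zeroes the gradient of the "frozen" entries, yet those entries still change, because the optimizer adds the weight-decay term itself.

```python
import torch

torch.manual_seed(0)

# Hypothetical mask: 1.0 marks trainable entries, 0.0 marks "frozen" ones.
mask = torch.zeros(4, 8)
mask[:, :2] = 1.0

w = torch.ones(4, 8, requires_grad=True)
# Zero the gradient of the frozen entries before it is stored in w.grad.
w.register_hook(lambda grad: grad * mask)

# Weight decay is applied inside the optimizer step, after the hook has run.
opt = torch.optim.SGD([w], lr=0.1, weight_decay=0.5)

x = torch.randn(8, 3)
loss = (w @ x).pow(2).mean()  # toy loss, stands in for the custom loss
loss.backward()

frozen_grad = w.grad[mask == 0].clone()  # all zeros thanks to the hook
opt.step()
frozen_after = w.detach()[mask == 0]     # no longer 1.0: decayed by lr * weight_decay
```

So a zero gradient does not guarantee a zero update; whether that matters depends on the optimizer settings in use.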

@ptrblck thanks a lot for pointing me to that question. I am new to PyTorch and am finding the hint difficult to follow. Could you please help me understand the hint given by @tom? Please excuse me if my question is very basic and naive. The code below is in no way functional; I just pieced together some pseudocode from @tom's answer.

    import torch
    import torch.nn as nn

    class CustomLoss(nn.Module):
        def __init__(self, the_mask, epochs=5000, learning_rate=0.1):
            super().__init__()
            self.epochs = epochs
            self.learning_rate = learning_rate
            # self.lamb is used in fit() but is not set anywhere yet
            self.register_buffer('weight_update_mask', the_mask)
            # self.weight_param and self.weight_fixed are not defined yet (see doubt 2)

        def forward(self, x):
            # take weight_param where the mask is set, weight_fixed elsewhere
            weight = torch.where(self.weight_update_mask, self.weight_param, self.weight_fixed)
            return torch.matmul(weight, x)

        def f_loss(self, Y, pred, w, lamb):
            # pred_loss and reg are placeholders for my prediction loss and regulariser
            return (1 / Y.size(0)) * pred_loss + lamb * reg

        def fit(self, X_pt, Y_pt, w):
            w_pt = torch.tensor(w, requires_grad=True)  # we compute the derivative w.r.t. w
            opt = torch.optim.Adam([w_pt], lr=self.learning_rate, betas=(0.9, 0.99),
                                   eps=1e-08, weight_decay=0, amsgrad=False)
            for epoch in range(self.epochs):
                pred = torch.matmul(X_pt, w_pt)
                loss = self.f_loss(Y_pt, pred, w_pt, self.lamb)  # loss value, a scalar
                loss.backward()   # computes the derivative of the loss with respect to w
                opt.step()        # performs an update to the parameters
                opt.zero_grad()   # resets the gradients, else they accumulate
            return w_pt

I have the following doubts:
1. Should the_mask be a tensor of 0's and 1's, with 1's at the positions whose gradients we want to compute and 0's at the positions where we want to skip gradient computation?
2. What is the significance of the self.weight_param and self.weight_fixed variables? Are they computed automatically by PyTorch, or should I define and initialise them in my __init__ function? If so, how do I initialise them?
3. What changes should be made to my fit function? (I don't expect code, but please at least explain in words.)

For now I have attempted to solve my problem by setting the unnecessary weights to 0 after every step of gradient descent. To do this I multiply the weight matrix element-wise by a mask matrix that has 1's at the positions we want to learn and 0's at the positions we want to ignore. It seems to have solved my problem, but I don't know whether it is correct to do it this way.
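For reference, the re-zeroing workaround described above can be sketched as follows. This is a minimal sketch under assumptions: the data is random, the squared-error loss stands in for the custom loss, and the mask layout (two trainable entries per row, at columns 2*row and 2*row+1) follows the pattern described in the first post.

```python
import torch

torch.manual_seed(0)

# 1.0 where a weight is learned, 0.0 where it must stay zero.
mask = torch.zeros(16, 32)
for row in range(16):
    mask[row, 2 * row: 2 * row + 2] = 1.0

w = torch.randn(16, 32, requires_grad=True)
with torch.no_grad():
    w.mul_(mask)  # start from a weight that already satisfies the constraint

opt = torch.optim.Adam([w], lr=0.1)
x = torch.randn(32, 8)       # toy inputs, one column per sample
target = torch.randn(16, 8)  # toy targets

for _ in range(100):
    opt.zero_grad()
    loss = ((w @ x - target) ** 2).mean()  # stand-in for the custom loss
    loss.backward()
    opt.step()
    with torch.no_grad():
        w.mul_(mask)  # re-zero the unwanted weights in place, outside autograd
```

Because the forward pass always sees zeros in the masked positions, the gradients of the learned entries are computed at the intended point. With an element-wise optimizer such as Adam this should give the same updates for the learned entries as masking inside the model, although the optimizer still keeps state for the masked entries. Anything that couples entries, such as clipping by the global gradient norm, could behave differently.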

Generally it is quite close to what I had in mind.

I'd make the mask a boolean tensor, so True for trained, False for fixed. (Mind you, see how neatly the first letters align? Programming is poetry. Just some form that is modern enough to be entirely tedious.)
The self.weight_fixed should be a buffer, too.

self.weight_param should be an nn.Parameter, initialized like you would initialize plain parameters. Note that wherever the mask is False, the parameter is ignored. The weight_fixed should be whatever you see fit; here, the entries where the mask is True are ignored.
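Putting these points together, a minimal sketch of such a module might look like the following. The class name MaskedWeight, the zero default for weight_fixed, the random initialization, the toy data and the squared-error loss are all assumptions made for illustration; the mask layout follows the pattern described in the first post.

```python
import torch
import torch.nn as nn


class MaskedWeight(nn.Module):
    """Weight matrix in which only the entries selected by a boolean mask are trained."""

    def __init__(self, mask, weight_fixed=None):
        super().__init__()
        if weight_fixed is None:
            # Assumption: untrained entries should be zero, as in the question.
            weight_fixed = torch.zeros(mask.shape)
        # Buffers follow .to(device) and are saved in the state_dict, but they are
        # not returned by .parameters(), so the optimizer never updates them.
        self.register_buffer("weight_update_mask", mask.bool())
        self.register_buffer("weight_fixed", weight_fixed)
        # Full-size parameter; entries where the mask is False are ignored.
        self.weight_param = nn.Parameter(torch.randn(mask.shape) * 0.1)

    def forward(self, x):
        # True -> take the trained parameter, False -> take the fixed value.
        weight = torch.where(self.weight_update_mask, self.weight_param, self.weight_fixed)
        return torch.matmul(weight, x)


torch.manual_seed(0)

# Two trainable entries per row of a 16 x 32 weight.
mask = torch.zeros(16, 32, dtype=torch.bool)
for row in range(16):
    mask[row, 2 * row: 2 * row + 2] = True

model = MaskedWeight(mask)
opt = torch.optim.Adam(model.parameters(), lr=0.1)

x = torch.randn(32, 8)       # toy inputs, one column per sample
target = torch.randn(16, 8)  # toy targets

for _ in range(100):
    opt.zero_grad()
    loss = ((model(x) - target) ** 2).mean()  # stand-in for the custom loss
    loss.backward()
    opt.step()

# The weight actually used by the model: trained where mask is True, fixed elsewhere.
effective_weight = torch.where(model.weight_update_mask, model.weight_param, model.weight_fixed)
```

In this sketch the fit loop needs no special handling: the optimizer only receives weight_param, and torch.where passes no gradient to the entries where the mask is False.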