Stop gradient computation for specific weights in a weight matrix

I am trying to develop a model for parameter estimation in which some parameters are allowed to train and some are not. For example, in the following matrix w_matrix = np.array([[w1,0.0,0.0,-w2],[w3,-w4,0.0,0.0],[0.0,w5,-w6,0.0]]) all 0.0 entries are fixed weights and all w entries are allowed to update.
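To make the intended structure concrete, here is a small sketch of the weight matrix together with a mask marking which entries are trainable (the numeric values are just placeholders):

```python
import numpy as np

# placeholder values for the trainable weights
w1, w2, w3, w4, w5, w6 = 2.0, 3.7, 1.5, 1.1, 0.9, 0.1

# the 0.0 entries must stay fixed at zero; only the w's should ever be updated
w_matrix = np.array([[w1,   0.0,  0.0, -w2],
                     [w3,  -w4,   0.0,  0.0],
                     [0.0,  w5,  -w6,   0.0]])

# 1 = trainable entry, 0 = fixed entry
train_mask = (w_matrix != 0.0).astype(float)
```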

My attempt is as follows:

class NeuralNet(nn.Module):
    
    def __init__(self,param):
        super(NeuralNet, self).__init__() 
        param.append(0.0)
        self.gamma = torch.tensor(param[0])
        self.gamma = Variable(self.gamma,
                              requires_grad=True).type(dtype)
        self.alpha_xy = torch.tensor(param[1])
        self.alpha_xy = Variable(self.alpha_xy,
                                 requires_grad=True).type(dtype)
        self.beta_y = torch.tensor(param[2])
        self.beta_y = Variable(self.beta_y,
                               requires_grad=True).type(dtype)
        self.alpha0 = torch.tensor(param[3])
        self.alpha0 = Variable(self.alpha0,
                               requires_grad=True).type(dtype)
        self.alpha_y = torch.tensor(param[4])
        self.alpha_y = Variable(self.alpha_y,
                               requires_grad=True).type(dtype)
        self.alpha1 = torch.tensor(param[5])
        self.alpha1 = Variable(self.alpha1,
                               requires_grad=True).type(dtype)
        self.alpha2 = torch.tensor(param[6])
        self.alpha2 = Variable(self.alpha2,
                               requires_grad=True).type(dtype)
        self.alpha3 = torch.tensor(param[7])
        self.alpha3 = Variable(self.alpha3,
                               requires_grad=True).type(dtype)
        self.zero_fix_wt = torch.tensor(param[8])
        self.zero_fix_wt = Variable(self.zero_fix_wt,
                               requires_grad=False).type(dtype)
        
        # Production Matrix
        self.w_production = torch.tensor([[self.gamma, self.zero_fix_wt, self.zero_fix_wt, self.zero_fix_wt],
                                         [self.beta_y, self.zero_fix_wt, self.zero_fix_wt, self.zero_fix_wt],
                                         [self.zero_fix_wt, self.alpha0, self.zero_fix_wt, self.zero_fix_wt]])
        self.w_production = Variable(self.w_production.to(device)).type(dtype)
        self.w_production = nn.Parameter(self.w_production)
        
        # Degradation Matrix
        self.w_decay = torch.tensor([[-self.alpha1, self.zero_fix_wt, self.zero_fix_wt, self.zero_fix_wt],
                                     [self.zero_fix_wt, -self.alpha2, self.zero_fix_wt, self.zero_fix_wt],
                                     [self.zero_fix_wt, self.zero_fix_wt, -self.alpha_y, self.zero_fix_wt]])
        self.w_decay = Variable(self.w_decay.to(device)).type(dtype)
        self.w_decay = nn.Parameter(self.w_decay)        
        
        # Cross-Talk Matrix
        self.w_cross_talk = torch.tensor([[self.zero_fix_wt, self.zero_fix_wt, self.zero_fix_wt, -self.alpha_xy],
                                         [self.zero_fix_wt, self.zero_fix_wt, self.zero_fix_wt, self.zero_fix_wt],
                                         [self.zero_fix_wt, self.zero_fix_wt, self.zero_fix_wt, -self.alpha3]])
        self.w_cross_talk = Variable(self.w_cross_talk.to(device)).type(dtype)
        self.w_cross_talk = nn.Parameter(self.w_cross_talk)
        
        
        
    def forward(self,input): 
        xy_term = input[0][0] * input[0][2]        
        input_new = torch.cat((input,xy_term.view([-1,1])),1)
        input_new = input_new.to(device,non_blocking=True)        
        hidden_state = (self.w_production + self.w_decay + self.w_cross_talk).mm(torch.transpose(input_new,0,1))
        out = hidden_state.view([-1,3]) + input
        return (out, hidden_state)
    
   

# parameter initialization
init_params = [2.0,3.7,1.5,1.1,0.9,0.1,0.9,0.01]

# define model in GPU
model = NeuralNet(init_params)
model = model.to(device)

loss_function = nn.MSELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,momentum = 0.9)

start = datetime.now()
loss_1 = []

for i in tqdm(range(epochs)):
    total_loss = 0
    
    for j in range(x_train.size(0)):
        optimizer.zero_grad()
        input = x_train[j:(j+1)]
        target = y_train[j:(j+1)]        
        
        input = input.to(device,non_blocking=True)
        target = target.to(device,non_blocking=True)        
        
        (pred, _) = model(input)
        loss = loss_function(pred,target)
        total_loss += loss        
        loss.backward()   
        optimizer.step()  
        print(model.w_production)
                    
                   
    
    # Early Stopping
    loss_1.append(total_loss.item())        
    if len(loss_1) > 1 and np.abs(loss_1[-2] - total_loss.item()) <= 0.10:
        print('Early stopping because the difference in error is less than the threshold (0.10)')
        print('Stopping training at %d epochs with total loss = %f' % (i, total_loss))
        break


    if i % 50 == 0:
        print("Epoch: {} loss {}".format(i, total_loss))
        
end = datetime.now()
time_taken = end - start
print('Execution Time: ',time_taken)

The problem I am facing is that loss.backward() and optimizer.step() compute gradients and update all weights, even though I have set requires_grad=False for some of them. This happens because requires_grad=False applies either to the complete matrix or not at all.
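To illustrate what I mean, here is a minimal check of how I understand the behaviour (the tensor is just an example):

```python
import torch

w = torch.zeros(3, 4, requires_grad=True)  # the flag applies to the whole tensor
print(w[0, 1].requires_grad)               # True - every entry inherits it
# w[0, 1].requires_grad = False            # RuntimeError: requires_grad can only
#                                          # be changed on leaf tensors, and a
#                                          # sliced entry is not a leaf
```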

My question is: how do I tell the model not to touch the fixed weights and never compute gradients for the specific entries for which requires_grad=False is set?

I think if you register a hook and zero out the gradients with a mask, then it will work. For example,
```python
def hook_fn(grad):
    out = grad.clone()
    out[mask] = 0
    return out

layer.weight.register_hook(hook_fn)
```
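A more complete, self-contained sketch of the same idea (the nn.Linear layer and the mask values below are only placeholders):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3, bias=False)                 # placeholder 3x4 weight matrix
mask = torch.tensor([[False, True,  True,  True],   # True = entry to keep fixed
                     [False, False, True,  True],
                     [True,  False, False, True]])

def hook_fn(grad):
    out = grad.clone()
    out[mask] = 0            # zero the gradient of the fixed entries
    return out

layer.weight.register_hook(hook_fn)

x = torch.randn(2, 4)
layer(x).sum().backward()
print(layer.weight.grad)     # gradients at the masked positions are zero
```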

Thanks. Isn't this similar to manually zeroing the gradients of selected weights? If so, I don't think it will work, because a 0.0 in the weight matrix means the link is not present, whereas for the model the link is present with zero weight. These are two very different things: one is a constrained-optimization aspect and the other is a model-architecture aspect.

My guess is that this problem can be solved by the following trick:

Stop loss.backward() from computing gradients at every computational node and apply .backward() only at the permitted nodes, so that optimizer.step() updates only the permitted weights and does not touch the fixed ones.

But now the question is how to customize loss.backward(), or is there another efficient and fast way?

My hypothesis of writing a customized loss function to deal with the customized model was wrong, so I tried something different.
I have figured out one solution using this reference (but again there are some issues in the model, and I am not able to understand why they are happening).

My codes are as follows:

class CustomizedForwardBackwardFunc(torch.autograd.Function):
    '''
        Autograd function which masks its weights with 'mask'
    '''
    
    @staticmethod
    def forward(ctx, input, production_weight, decay_weight, 
                cross_talk_weight, bias, prod_mask, decay_mask, cross_mask):
        
        if prod_mask.size() == production_weight.size():
            production_weight = torch.mul(production_weight, prod_mask)
            decay_weight = torch.mul(decay_weight, decay_mask)
            cross_talk_weight = torch.mul(cross_talk_weight, cross_mask)
        else:            
            raise Exception('Please check your production mask matrix size')
            
        xy_term = input[0][0] * input[0][2]        
        input_new = torch.cat((input,xy_term.view([-1,1])),1)
        input_new = input_new.to(device,non_blocking=True)   
        
        production_dt = production_weight.mm(torch.transpose(input_new,0,1))
        decay_dt = decay_weight.mm(torch.transpose(input_new,0,1))
        cross_talk_dt = cross_talk_weight.mm(torch.transpose(input_new,0,1))
        hidden_state = production_dt - decay_dt - cross_talk_dt
        out = hidden_state.view([-1,3]) + input
        
        if bias is not None:
            out += bias.unsqueeze(0).expand_as(out)
            
        ctx.save_for_backward(input_new,production_weight,decay_weight,
                              cross_talk_weight,bias,prod_mask,decay_mask,
                              cross_mask)
        
        return out
    
    
    @staticmethod
    def backward(ctx, grad_out):
                
        input_new,production_weight,decay_weight,cross_talk_weight,bias,prod_mask,decay_mask,cross_mask = ctx.saved_tensors
        grad_input = grad_prod_wt = grad_decay_wt = grad_cross_wt = None
        grad_bias = grad_prod_mask = grad_decay_mask = grad_cross_mask = None        
        
        grad_out_new = torch.cat((grad_out,(grad_out[0][0] + grad_out[0][2]).view([-1,1])),1)
        
        if ctx.needs_input_grad[0]:            
            grad_input = (production_weight - decay_weight - cross_talk_weight).mm(grad_out_new)
        
        if ctx.needs_input_grad[1]:
            grad_prod_wt = grad_out.t().clone().mm(input_new)            
            grad_prod_wt = torch.mul(grad_prod_wt, prod_mask)

            grad_decay_wt = grad_out.t().clone().mm(input_new)
            grad_decay_wt = torch.mul(grad_decay_wt, decay_mask)

            grad_cross_wt = grad_out.t().clone().mm(input_new)
            grad_cross_wt = torch.mul(grad_cross_wt, cross_mask)
                        
#         if ctx.needs_input_grad[2]: #don't train for bias
#             grad_bias = (grad_prod_out + grad_decay_out + grad_cross_out).sum(0).squeeze(0)
            
        return grad_input, grad_prod_wt, grad_decay_wt, grad_cross_wt, grad_bias, grad_prod_mask, grad_decay_mask, grad_cross_mask
        
        
class NeuralNet(nn.Module):
    
    def __init__(self,param, prod_mask,decay_mask,cross_mask, bias=True):
        super(NeuralNet, self).__init__() 
        
        if isinstance(prod_mask, torch.Tensor):
            self.prod_mask = prod_mask.type(torch.float)
        else:
            self.prod_mask = torch.tensor(prod_mask, dtype=torch.float)
            
        self.prod_mask = nn.Parameter(self.prod_mask.to(device),requires_grad=False)
        
        if isinstance(decay_mask, torch.Tensor):
            self.decay_mask = decay_mask.type(torch.float)
        else:
            self.decay_mask = torch.tensor(decay_mask, dtype=torch.float)
            
        self.decay_mask = nn.Parameter(self.decay_mask.to(device),requires_grad=False)
        
        if isinstance(cross_mask, torch.Tensor):
            self.cross_mask = cross_mask.type(torch.float)
        else:
            self.cross_mask = torch.tensor(cross_mask, dtype=torch.float)
            
        self.cross_mask = nn.Parameter(self.cross_mask.to(device),requires_grad=False)
        
        if bias:
            self.bias = Variable(torch.zeros(self.prod_mask.shape[0])).type(dtype)
            self.bias = nn.Parameter(self.bias.to(device), requires_grad=False)
        else:
            self.register_parameter('bias',None)
        
        # param_order = [gamma,alpha_xy,beta_y,alpha0,alpha_y,alpha1,alpha2,alpha3]
        # Production Matrix
        self.w_production = torch.tensor([[param[0], 0.0, 0.0, 0.0],
                                          [param[2], 0.0, 0.0, 0.0],
                                          [0.0, param[3], 0.0, 0.0]])
        self.w_production = Variable(self.w_production).type(dtype)
        self.w_production = nn.Parameter(self.w_production.to(device),
                                         requires_grad=True)
        
        # Degradation Matrix
        self.w_decay = torch.tensor([[param[5], 0.0, 0.0, 0.0],
                                     [0.0, param[6], 0.0, 0.0],
                                     [0.0, 0.0, param[4], 0.0]])
        self.w_decay = Variable(self.w_decay).type(dtype)
        self.w_decay = nn.Parameter(self.w_decay.to(device),
                                    requires_grad = True)        
        
        # Cross-Talk Matrix
        self.w_cross_talk = torch.tensor([[0.0, 0.0, 0.0, param[1]],
                                          [0.0, 0.0, 0.0, 0.0],
                                          [0.0, 0.0, 0.0, param[7]]])
        self.w_cross_talk = Variable(self.w_cross_talk).type(dtype)
        self.w_cross_talk = nn.Parameter(self.w_cross_talk.to(device),
                                         requires_grad = True) 
               
         
        self.w_production.data = torch.mul(self.w_production.data, self.prod_mask)
        self.w_decay.data = torch.mul(self.w_decay.data, self.decay_mask)
        self.w_cross_talk.data = torch.mul(self.w_cross_talk.data, self.cross_mask)
        
        
        
    def forward(self,input): 
        
        return CustomizedForwardBackwardFunc.apply(input,self.w_production,
                                               self.w_decay,self.w_cross_talk,
                                               self.bias,self.prod_mask,
                                               self.decay_mask,
                                               self.cross_mask)
    
    

# params initialized  from lmfit
# init_params = [gamma,alpha_xy,beta_y,alpha0,alpha_y,alpha1,alpha2,alpha3]
# init_params = [2.2, 3.1, 1.1, 1.0, 0.79, 0.07, 0.5, 0.008]
init_params = [2.0,3.7,1.5,1.1,0.9,0.1,0.9,0.01]
prod_mask = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0, 0.0]])

decay_mask = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0, 0.0],
                           [0.0, 0.0, 1.0, 0.0]])

cross_mask = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                           [0.0, 0.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0, 1.0]])

# define model in GPU
model = NeuralNet(init_params,prod_mask,decay_mask,cross_mask)
model = model.to(device)

loss_function = nn.MSELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,momentum = 0.9)

start = datetime.now()
learning_rate = 1e-6
loss_1 = []

for i in tqdm(range(epochs)):
    total_loss = 0
    
    for j in range(x_train.size(0)):
        optimizer.zero_grad() 
        
#         for param in model.parameters():
#             if param.requires_grad == True:
#                 param.retain_grad()
#                 param.grad.data.zero_() #commented because of error of  
                                          #getting "None" from param.grad
        
        x1 = x_train[j:(j+1)]
        y1 = y_train[j:(j+1)]         
        x1 = x1.to(device,non_blocking=True)
        y1 = y1.to(device,non_blocking=True)                
        pred = model(x1)
               
        loss = loss_function(pred,y1)
        total_loss += loss               
        loss.backward()            
#         optimizer.step() # commented because we're using customized 
                           # gradient upgradation
        with torch.no_grad():
            for name, param in model.named_parameters():                
                if param.requires_grad == True:
                    param -= learning_rate * param.grad
                                        
                    '''
                        Print statement to cross-check the masked
                        gradients
                    '''
#                     print(param.grad)
                    # check masked param.grad
#                     if np.array(param.grad.cpu()).size == np.array(prod_mask).size:
#                         print('--- epoch={}, loss={} ---'.format(t,loss.item()))
#                         print('↓↓↓masked weight↓↓↓')
#                         print(param)
#                         print('↓↓↓masked grad of weight↓↓↓')
#                         print(param.grad)
             
    
    # Early Stopping
    loss_1.append(total_loss.item())  
#     print(total_loss.item())
    if len(loss_1) > 1 and np.abs(loss_1[-2] - total_loss.item()) <= 0.01:
        print('Early stopping because the difference in error is less than the threshold (0.01)')
        print('Stopping training at %d epochs with total loss = %f' % (i, total_loss))
        break


    if i % 50 == 0:
        print("Epoch: {} loss {}".format(i, total_loss))
        
end = datetime.now()
time_taken = end - start
print('Execution Time: ',time_taken)

First I tried learning_rate = 1e-3 and the loss kept increasing. Then I tried 1e-4, with the same response from the model. Then I tried 1e-5; the loss decreased for 14 epochs but then started increasing. So I lowered the learning rate further to 1e-6, and the model performed well. I also used an early stopping criterion of a difference in error of 0.1.

Here are the model prediction results:

My issues are as follows:

  1. With learning rate 1e-6 and the early stopping criterion (i.e., difference in epoch error less than 0.1), the error went down to 1880 and then training stopped. (Ideally it should go down to zero.) Why?
  2. I lowered the early stopping threshold to 0.01 with the same learning rate of 1e-6; the model then performed well up to 150 epochs (error at epoch 150 = 1880), but then the error started increasing again.

I thought lowering the error criterion would help the model, but the opposite is happening. My doubts are:

  1. Is there anything wrong in the model such that the error does not go below 1880?
  2. I am feeding one training sample at a time (batch size 1). How does batch_size affect model training performance? (See the sketch below.)
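For reference, this is roughly what I mean by changing batch_size; a mini-batch loader might look like the sketch below (torch.utils.data is standard PyTorch, but the batch size of 32 is an arbitrary choice, and the current forward() only handles a single sample, so it would need changes first):

```python
from torch.utils.data import TensorDataset, DataLoader

# hypothetical mini-batch loading over the existing x_train / y_train tensors
train_ds = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

for x_batch, y_batch in train_loader:
    x_batch = x_batch.to(device, non_blocking=True)
    y_batch = y_batch.to(device, non_blocking=True)
    # pred = model(x_batch)  # forward() would need to support batched input
```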

@uchida-takumi @jamesproud @Arvind_Subramaniam: What would you suggest? Thanks.