Pre-Trained Model Weights Do Not Update During Fine-Tuning

I’m trying to apply DPO (Direct Preference Optimization) to a pre-trained model. The loss is computed from scores returned by the fine-tuned and vanilla (reference) models. However, during training, the scores given by the pre-trained model and the fine-tuned model are identical, and the loss remains the same across all batches, which leads me to believe the weights are not being updated. My training method is given below.
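
For context, this is meant to implement the standard DPO objective from Rafailov et al. (2023):

$$\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}\!\left[\log \sigma\!\left(\beta\left(\log\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right)\right]$$

where $\pi_\theta$ is the model being fine-tuned, $\pi_{\mathrm{ref}}$ is the frozen reference model, and $y_w$ / $y_l$ are the preferred / dispreferred responses.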

def train(model, optimizer, pref_set, dispref_set, epochs, beta, bs):
    model.train()
    #print(list(model.parameters())[0])
    #print(list(model.parameters())[0].grad)
    for epoch in range(epochs):
        cur_pref = []
        cur_dispref = []
        for i in range(len(pref_set)):
            cur_pref.append(pref_set[i])
            cur_dispref.append(dispref_set[i])  # collect preferred and dispreferred responses
            if (i + 1) % bs == 0:
                make_fastas(cur_pref, cur_dispref)  # set up the FASTA files for scoring
                run_mpnn('model-DPO')  # score the responses
                optimizer.zero_grad()
                b_ref, nb_ref, b_dpo, nb_dpo = collect_logps(cur_pref)  # collect the scores
                loss = calc_loss(b_dpo, nb_dpo, b_ref, nb_ref, beta)  # compute the DPO loss
                print(loss)
                loss.backward()
                optimizer.step()
                print(optimizer)
                torch.save({  # save the updated model for the next round of scoring
                        'epoch': epoch + 1,
                        'step': i,
                        'num_edges': 48,
                        'noise_level': 0.2,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        }, "../ProteinMPNN/vanilla_model_weights/model-DPO.pt")
                print(loss)
                cur_pref = []
                cur_dispref = []

In short, the scoring of my preferred and dispreferred responses has to be done in a separate script, so I have to save the updated model after each batch and load it for the next round of scoring. But as I mentioned, the model weights are not changing, and the scores returned by the reference and target models are always identical. Any help resolving this issue would be greatly appreciated.

I’ve checked that the model parameters are initialized correctly, with requires_grad=True, and that they have no gradient before training (list(model.parameters())[0].grad is None). I’ve also checked that I’m not overwriting the updated model weights or accidentally loading the vanilla weights during scoring. I double-checked my loss function and tried setting the loss and learning rate to arbitrarily large values to force the weights to update, but the scores never changed. The parameter gradient after the backward call is still None, and I’m not sure why. As mentioned above, all model parameters are initialized with requires_grad=True; I even tried setting requires_grad=False for all of them, and the behavior was no different.
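
The checks look roughly like this (inspecting just the first parameter, as in the commented-out prints above):

p = list(model.parameters())[0]
print(p.requires_grad)  # True, so the parameter is trainable
print(p.grad)           # None before training, as expected

# ...after loss.backward() inside train():
print(p.grad)           # still None, which is the problem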

Hi Jonathan,
Could you please share what exactly your collect_logps function looks like? I can try to help after taking a look at it.

What we need to make sure of is that collect_logps uses the exact model instance you’re trying to update. I don’t know what your collect_logps looks like, but how about passing the model instance to it as one of its parameters?
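
Something along these lines, as a rough sketch with hypothetical names (I haven’t seen your function, so adapt the signature as needed):

def collect_logps(model, ref_model, batch):
    # Scoring with the live model instance keeps the outputs on the autograd graph
    model_logps = model(batch)
    with torch.no_grad():
        ref_logps = ref_model(batch)  # reference scores don't need gradients
    return model_logps, ref_logps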

Hi Srishti,

Thank you so much for the swift response! I’ve copied my other methods below.

def calc_loss(model_prefered_logprob, model_disprefered_logprob, ref_prefered_logprob, ref_disprefered_logprob, beta=0.5):
    prefered_relative_logprob = [a - b for a, b in zip(model_prefered_logprob, ref_prefered_logprob)]
    print(prefered_relative_logprob)
    disprefered_relative_logprob = [a - b for a, b in zip(model_disprefered_logprob, ref_disprefered_logprob)]
    print(disprefered_relative_logprob)
    new = [beta * (a - b) for a, b in zip(prefered_relative_logprob, disprefered_relative_logprob)]
    print(new)
    loss = torch.mean(F.logsigmoid(torch.Tensor(new)))
    loss.requires_grad = True
    return loss

def make_fastas(pref, dispref):
    # Write the current batch of sequences to FASTA files for the scoring script
    ctr = 0
    with open('temp_dispref.fa', 'w') as f:
        for a in dispref:
            f.write('>' + str(ctr) + '\n')
            f.write(a + '\n')
            ctr += 1
    ctr = 0
    with open('temp_pref.fa', 'w') as f:
        for a in pref:
            f.write('>' + str(ctr) + '\n')
            f.write(a + '\n')
            ctr += 1

def run_mpnn(new_model_name):
    # Score both response sets with the frozen reference weights (v_48_020)
    subprocess.call(['./mpnn_score.sh','pdb_in/HTRA1_1.pdb','temp_pref.fa','bind_out_ref/','v_48_020'])
    subprocess.call(['./mpnn_score.sh','pdb_in/HTRA1_1.pdb','temp_dispref.fa','nonbind_out_ref/','v_48_020'])
    # Score them again with the DPO checkpoint saved at the end of each batch in train()
    subprocess.call(['./mpnn_score.sh','pdb_in/HTRA1_1.pdb','temp_pref.fa','bind_out_dpo/',new_model_name])
    subprocess.call(['./mpnn_score.sh','pdb_in/HTRA1_1.pdb','temp_dispref.fa','nonbind_out_dpo/',new_model_name])

def collect_logps(pref):
    # Load the per-sequence scores written by mpnn_score.sh, deleting each file after reading
    bind_scores_ref = []
    nonbind_scores_ref = []
    bind_scores_dpo = []
    nonbind_scores_dpo = []
    for i in range(len(pref)):
        fbref = np.load('bind_out_ref/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz')
        subprocess.call(['rm','bind_out_ref/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz'])
        
        fnbref = np.load('nonbind_out_ref/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz')
        subprocess.call(['rm','nonbind_out_ref/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz'])
        
        fbdpo = np.load('bind_out_dpo/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz')
        subprocess.call(['rm','bind_out_dpo/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz'])
        
        fnbdpo = np.load('nonbind_out_dpo/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz')
        subprocess.call(['rm','nonbind_out_dpo/score_only/HTRA1_1_fasta_'+str(i+1)+'.npz'])
        
        bind_scores_ref.append(fbref['score'][0])
        nonbind_scores_ref.append(fnbref['score'][0])
        bind_scores_dpo.append(fbdpo['score'][0])
        nonbind_scores_dpo.append(fnbdpo['score'][0])
    return bind_scores_ref, nonbind_scores_ref, bind_scores_dpo, nonbind_scores_dpo

To put it simply: make_fastas prepares the batch for input to run_mpnn; run_mpnn uses subprocess calls to run the scoring script, which writes out files containing the log-probabilities; collect_logps reads those output files and loads the scores, which are then passed to calc_loss to compute the loss before the backward pass. Of these methods, only run_mpnn needs the model; the model is saved separately and loaded by the scoring script that run_mpnn invokes.

Thanks for sharing the code. The issue is that your model parameters are not part of the computational graph that produces loss: the scores are computed in a separate process and read back from disk as numpy values, so autograd has no record connecting them to your model’s weights. I recommend first going through PyTorch’s official blogs on autograd and computational graphs to understand how gradient calculation is done; the tutorials should help, too.
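
You can verify this directly: a tensor built from values read off disk is a leaf with no autograd history, so there is nothing for backward() to traverse. A minimal illustration:

import torch
import torch.nn.functional as F

scores = torch.Tensor([0.3, -1.2])         # built from plain numbers: a leaf with no history
loss = torch.mean(F.logsigmoid(scores))
print(loss.grad_fn)                        # None -- backward() cannot reach any parameters

w = torch.nn.Parameter(torch.zeros(2))
loss2 = torch.mean(F.logsigmoid(w))
print(loss2.grad_fn)                       # <MeanBackward0> -- connected, so gradients flow to w

Setting loss.requires_grad = True, as your calc_loss does, makes backward() run without an error, but the gradient stops at loss itself, which is why every parameter’s .grad stays None.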

Please also make sure that all operations in the loss computation are performed on PyTorch tensors end to end, rather than building Python lists first and then wrapping them in a tensor, as you’ve done here:

loss = torch.mean(F.logsigmoid(torch.Tensor(new)))
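
Once the log-probabilities come from the live model’s forward pass, the whole loss can be written in tensor ops. A sketch, assuming each argument is a 1-D tensor that is still attached to the graph:

def calc_loss(model_pref_logp, model_dispref_logp,
              ref_pref_logp, ref_dispref_logp, beta=0.5):
    pref_rel = model_pref_logp - ref_pref_logp            # elementwise tensor ops preserve the graph
    dispref_rel = model_dispref_logp - ref_dispref_logp
    # Note the leading minus sign: the standard DPO objective minimizes
    # the negative log-sigmoid of the scaled margin.
    return -torch.mean(F.logsigmoid(beta * (pref_rel - dispref_rel)))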

Hi Srishti,

Thank you again for the help! I’ll see if I can use the model parameters directly for the loss calculation.