Training loss with contrastive pairs does not change between epochs

I’m fine-tuning a pre-trained transformer on contrastive pairs using nn.MSELoss(). The loss updates after each batch, but it does not change from one epoch to the next.
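
To make the objective concrete: each pair (lhs, rhs) comes with a gold similarity score, and the model is trained so that the cosine similarity of the two mean-pooled embeddings matches that score. A minimal sketch of that loss computation on dummy tensors (shapes and values are only for illustration):

```
import torch
import torch.nn as nn

# pooled sentence embeddings for a toy batch: (batch_size, hidden_dim)
a = torch.randn(4, 768)   # anchor / lhs embeddings
p = torch.randn(4, 768)   # positive / rhs embeddings
labels = torch.rand(4)    # gold similarity scores

cos_sim = nn.CosineSimilarity(dim=1)
loss = nn.MSELoss()(cos_sim(a, p), labels)
print(loss.item())
```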

```
import os
import pickle

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from transformers import WEIGHTS_NAME, CONFIG_NAME  # "pytorch_model.bin" / "config.json"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# read_train_and_val_data, mean_pool and run_eval are project helpers (not shown here)


def train_model():

    dataset = read_train_and_val_data()

    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/stsb-mpnet-base-v2',do_lower_case=True)

    dataset = dataset.map(
        lambda x: tokenizer(
                x['lhs'], max_length=128, padding='max_length',
                truncation=True, return_tensors='pt'
            ), batched=True, batch_size=2000, num_proc=32
    )
    

    dataset = dataset.rename_column('input_ids', 'lhs_ids')
    dataset = dataset.rename_column('attention_mask', 'lhs_mask')

    dataset = dataset.map(
        lambda x: tokenizer(
                x['rhs'], max_length=128, padding='max_length',
                truncation=True, return_tensors='pt'
        ), batched=True, batch_size=2000, num_proc=32
    )

    dataset = dataset.rename_column('input_ids', 'rhs_ids') 
    dataset = dataset.rename_column('attention_mask', 'rhs_mask')
    dataset = dataset.remove_columns(['rhs', 'lhs'])

    # set the dataset format to PyTorch tensors for the DataLoader
    dataset.set_format(type='torch', output_all_columns=True)

    #save dataset to pickle file
    with open("train_dataset.pk", "wb") as train_file:
        pickle.dump(dataset, train_file)

    with open("train_dataset.pk", "rb") as train_file:
        dataset = pickle.load(train_file)


    batch_size = 128

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=16)
    # model = nn.DataParallel( BertModel.from_pretrained('bert-base-uncased'))
    model = AutoModel.from_pretrained('sentence-transformers/stsb-mpnet-base-v2')
    model = nn.DataParallel(model)
    model.to(device)

    cos_sim = torch.nn.CosineSimilarity()

    loss_func = nn.MSELoss()
    # move layers to device

    # initialize Adam optimizer
    optim = torch.optim.Adam(model.parameters(), lr=2e-5)

    # setup warmup for first ~10% of steps
    total_steps = int(len(dataset)/batch_size)
    warmup_steps = int(0.1 * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        optim, num_warmup_steps=warmup_steps,
        num_training_steps=total_steps-warmup_steps
    )


    from tqdm.auto import tqdm
    epochs = 8
   

    for epoch in range(epochs):
        losses = []
        model.train()
        output_dir = '/home/ubuntu/research/research_sandbox/wanjiru/nlp/gro_source_mapping/gro_nlp/src/models/torch_models/model_epoch_{}/'.format(epoch)
        loop = tqdm(loader, leave=True)
        for batch in loop:
            # batch.to(accelerator.device)
            torch.autograd.set_detect_anomaly(True)
            # zero all gradients on each new step
            optim.zero_grad()
            # prepare the batch tensors and move them all to the active device
            anchor_ids = batch['lhs_ids'].to(device)
            anchor_mask = batch['lhs_mask'].to(device)
            pos_ids = batch['rhs_ids'].to(device)
            pos_mask = batch['rhs_mask'].to(device)
            #forward pass
            a = model(
                anchor_ids, attention_mask=anchor_mask
            )
            p = model(
                pos_ids, attention_mask=pos_mask
            )
            a = mean_pool(a, anchor_mask)
            p = mean_pool(p, pos_mask)

            scores = cos_sim(a, p)
            scores.to(device)
            
            labels = batch['score'].to(device)

            loss = loss_func(scores, labels)
            losses.append(loss.item())
            
            # backward pass
            loss.backward()
            optim.step()
            # update learning rate scheduler
            scheduler.step()
            # update the tqdm progress bar
            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix(loss=loss.item())
        
        print(losses)
        # save model after each epoch
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(output_dir, CONFIG_NAME)

        #model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save = model.module
        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(output_dir)

        # evaluate model after each epoch:
        run_eval(output_dir, epoch)

train_model()
```
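
For reference, mean_pool is a helper for masked mean pooling over the token embeddings, along these lines (a simplified sketch, not the exact code used above):

```
def mean_pool(model_output, attention_mask):
    # token embeddings from the encoder: (batch, seq_len, hidden)
    token_embeddings = model_output.last_hidden_state
    # expand the attention mask so padded positions contribute nothing
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # average only over the real (non-padding) tokens
    return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
```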