Use model predictions on a held-out dataset in a custom loss function

I want to design a custom loss function of the form L = L_primary + alpha * new_loss.

Besides the primary torch.nn.CrossEntropyLoss() component, I want to add another new_loss component that uses the model's predictions on a held-out dataset, which is different from the training dataset used to compute torch.nn.CrossEntropyLoss().

While designing the new_loss component, I am not sure whether I should use with torch.set_grad_enabled(False): when predicting the outputs on the held-out dataset.
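
For reference, my understanding is that anything computed inside with torch.set_grad_enabled(False): is detached from the autograd graph, so no intermediate activations are stored, but also no gradients can flow back through it. A minimal toy sketch of what I mean (using a made-up linear model, not my actual model):

import torch

# Toy model purely for illustration, not my actual model.
toy_model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

# Normal forward pass: the output is attached to the autograd graph.
out_with_grad = toy_model(x)
print(out_with_grad.requires_grad)  # True

# Forward pass with gradients disabled: no graph is built, so
# intermediate activations are not kept in memory.
with torch.set_grad_enabled(False):
    out_no_grad = toy_model(x)
print(out_no_grad.requires_grad)  # False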

Here is a code snippet of my training:

for epoch in range(NUM_EPOCHS):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        ### FORWARD AND BACK PROP
        logits = model(features)
        probas = F.softmax(logits, dim=1)

        # Primary loss on the training batch (cost_fn is torch.nn.CrossEntropyLoss())
        L_primary = cost_fn(logits, targets)

        # Additional loss computed from predictions on the held-out dataset
        model.eval()
        new_loss = NEW_LOSS(model, data_loader)
        model.train()

        loss = L_primary + alpha * new_loss

        optimizer.zero_grad()
        loss.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

I want to backpropagate on the sum of both losses to update the model parameters. Please let me know whether the following is correct for training the model without running into memory issues from storing the outputs.

def NEW_LOSS(model, data_loader):

    ### Initialize the prediction list (tensor)
    predlist = torch.zeros(0, dtype=torch.float32, device=DEVICE)

    with torch.set_grad_enabled(False):
        for i, (features, targets) in enumerate(data_loader):
            features = features.to(DEVICE)
            targets = targets.to(DEVICE)
            logits = model(features)
            probas = F.softmax(logits, dim=1)

            # Append the batch predictions (probability of class 0)
            predlist = torch.cat([predlist, probas[:, 0].view(-1).type(torch.float32)])

    ### Compute the loss from the collected predictions
    new_loss = Compute_Loss(predlist)
    return new_loss
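
Here Compute_Loss is my own function that turns the collected predictions into a scalar loss. Purely as an illustrative placeholder to make the snippet self-contained (this is not my actual loss), it could be something as trivial as:

def Compute_Loss(predlist):
    # Placeholder only: average of the collected class-0 probabilities.
    # My real Compute_Loss is more involved.
    return predlist.mean()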

If I don’t use with torch.set_grad_enabled(False):, I run into memory issues. Any other suggestions to address my problem are also welcome. Help is greatly appreciated. Thank you.