How can I replace this for loop with a matrix (vectorized) operation?

This task is about classification.

When predicting the label, I want the model to take class-priority information into account, not just the probabilities.

For example, with 6 classes the priority information might be [0,0,0,1,1,2], with a threshold of 0.5.

A lower number means a more important class.

If the difference in probability between a more important class and the current predicted class does not exceed the threshold, I change the current prediction to that more important class.

For example, suppose the model outputs the probabilities [0.3, 0.2, 0.15, 0.15, 0.1, 0.1], the priority is [1,0,1,1,1,1], and the threshold is 0.2.

Normally the predicted class would be class 0, but with the priority information it changes to class 1
(class 1 is more important, and 0.3 - 0.2 = 0.1 < 0.2).
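The rule can be traced on a single sample with plain Python. In this sketch (the helper name `pick_with_priority` is hypothetical) the gap is measured from the current candidate, as in the loop further down, so the candidate can move more than once.

```python
def pick_with_priority(prob, priority, threshold):
    """Pick a class for one sample; a lower priority number means more important."""
    # Visit classes from most to least probable (ties keep their original order)
    order = sorted(range(len(prob)), key=lambda c: prob[c], reverse=True)
    cand = order[0]
    for c in order[1:]:
        # Stop once the gap to the current candidate exceeds the threshold
        if prob[cand] - prob[c] > threshold:
            break
        # Switch to a more important class that is close enough
        if priority[c] < priority[cand]:
            cand = c
    return cand


print(pick_with_priority([0.3, 0.2, 0.15, 0.15, 0.1, 0.1], [1, 0, 1, 1, 1, 1], 0.2))  # → 1
print(pick_with_priority([0.45, 0.30, 0.15, 0.10], [2, 1, 0, 2], 0.2))  # → 2
```

In the second call the candidate moves from class 0 to class 1 and then to class 2, because the second gap (0.30 - 0.15) is measured from class 1, not from the top-1 probability.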

I implemented this logic with the for loop below, but I think it could be replaced by matrix (tensor) operations. How can I do that?

```python
import torch

def modify_predResult_with_priority(prob, priority, threshold):
    # prob: [batch_size, num_cls]
    # priority: [num_cls] ex) 6 classes: [0,0,0,1,1,2] (lower = more important)
    # threshold: 0.5

    # Sort each row by probability (highest first) and look up the matching priorities
    sorted_prob, sorted_prob_cls_index = prob.sort(descending=True)
    sorted_prob_priority = torch.as_tensor(priority, device=prob.device)[sorted_prob_cls_index]

    # Start from the top-1 class; clone so updates don't write into the sorted tensors
    candidate_cls = sorted_prob_cls_index[:, 0].clone()
    candidate_prob = sorted_prob[:, 0].clone()
    candidate_priority = sorted_prob_priority[:, 0].clone()

    for batch in range(sorted_prob.shape[0]):
        for idx in range(1, sorted_prob.shape[1]):
            # Stop once the gap to the current candidate exceeds the threshold
            if (candidate_prob[batch] - sorted_prob[batch][idx]) > threshold:
                break
            # Switch to a more important class that is close enough
            if sorted_prob_priority[batch][idx] < candidate_priority[batch]:
                candidate_prob[batch] = sorted_prob[batch][idx]
                candidate_cls[batch] = sorted_prob_cls_index[batch][idx]
                candidate_priority[batch] = sorted_prob_priority[batch][idx]
    return candidate_cls
```
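Because the loop lowers `candidate_prob` each time it switches class, later comparisons depend on earlier switches, so a single boolean mask cannot reproduce it exactly. A minimal sketch: keep a short loop over "hops" and vectorize everything else across the batch. At each hop, every row jumps to the first later class (in sorted order) that is within `threshold` of its current candidate and has a lower priority number. Each hop strictly lowers the priority, so at most (number of distinct priority levels minus 1) hops can happen. The function name `modify_pred_with_priority_vectorized` is hypothetical, and the code assumes `priority` is a list or 1-D tensor of length `num_cls`.

```python
import torch

def modify_pred_with_priority_vectorized(prob, priority, threshold):
    # prob: [batch_size, num_cls]; priority: [num_cls], lower = more important
    sorted_prob, sorted_cls = prob.sort(dim=1, descending=True)
    sorted_pri = torch.as_tensor(priority, device=prob.device)[sorted_cls]  # [B, C]
    batch, num_cls = sorted_prob.shape
    rows = torch.arange(batch, device=prob.device)
    pos = torch.arange(num_cls, device=prob.device).expand(batch, num_cls)
    # Position (in sorted order) of each row's current candidate; starts at top-1
    cur = torch.zeros(batch, dtype=torch.long, device=prob.device)

    # Each hop strictly lowers the candidate's priority, so this loop ends early
    for _ in range(num_cls - 1):
        cur_prob = sorted_prob[rows, cur].unsqueeze(1)  # [B, 1]
        cur_pri = sorted_pri[rows, cur].unsqueeze(1)    # [B, 1]
        # Later than the candidate, within threshold of it, and more important
        ok = (
            (pos > cur.unsqueeze(1))
            & (cur_prob - sorted_prob <= threshold)
            & (sorted_pri < cur_pri)
        )
        # First qualifying position per row (num_cls means "none")
        nxt = torch.where(ok, pos, torch.full_like(pos, num_cls)).min(dim=1).values
        moved = nxt < num_cls
        if not moved.any():
            break
        cur = torch.where(moved, nxt, cur)

    # Map sorted positions back to class indices
    return sorted_cls[rows, cur]


prob = torch.tensor([[0.3, 0.2, 0.15, 0.15, 0.1, 0.1]])
print(modify_pred_with_priority_vectorized(prob, [1, 0, 1, 1, 1, 1], 0.2))  # tensor([1])
```

It uses the same sort and the same comparisons as the loop, so it is intended to give the same predictions. The remaining Python loop runs over hops rather than over samples, so its number of iterations does not grow with `batch_size`.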

Maybe something like this?

```python
# Get the positions that need a priority update
idxs = prob > threshold

# Update the priority information
priority[idxs] += 1
```
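The attempt above compares each raw probability with `threshold`, but the rule is about the gap between a class and the current prediction, so `prob > threshold` selects different classes (and its shape, [batch_size, num_cls], does not match `priority`, [num_cls]). If it is acceptable to measure every gap from the top-1 probability only, which is a simplification of the loop (the loop measures from the running candidate), the selection becomes a single masked operation. A sketch under that assumption; the name `pick_with_priority_top1_window` is hypothetical.

```python
import torch

def pick_with_priority_top1_window(prob, priority, threshold):
    # prob: [batch_size, num_cls]; priority: [num_cls], lower = more important.
    # Assumption: every gap is measured from the top-1 probability.
    pri = torch.as_tensor(priority, device=prob.device).repeat(prob.shape[0], 1)  # [B, C]
    top = prob.max(dim=1, keepdim=True).values       # [B, 1]
    in_window = (top - prob) <= threshold             # top-1 class is always inside
    # Push classes outside the window above every real priority value
    masked_pri = pri.masked_fill(~in_window, int(pri.max()) + 1)
    best_pri = masked_pri.min(dim=1, keepdim=True).values  # most important level in window
    # Among in-window classes at that level, keep the most probable one
    score = prob.masked_fill(masked_pri != best_pri, float("-inf"))
    return score.argmax(dim=1)


prob = torch.tensor([[0.3, 0.2, 0.15, 0.15, 0.1, 0.1]])
print(pick_with_priority_top1_window(prob, [1, 0, 1, 1, 1, 1], 0.2))  # tensor([1])
```

On the question's example this also returns class 1, but it can differ from the loop. For probabilities [0.45, 0.30, 0.15, 0.10] with priority [2, 1, 0, 2] and threshold 0.2 it returns class 1, while the loop returns class 2, because the loop switches to class 1 first and then measures the next gap from 0.30.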