How can I convert this for loop into a matrix operation?

I want the model to make predictions with priority information, not just probability, when predicting a true label.

For example, with 6 classes the priority information might be [0,0,0,1,1,2], with threshold: 0.5

A lower number means a more important class!

If the difference in probability between the current prediction label and a label with a more important priority does not exceed the threshold, I change the current prediction label to the label with the more important priority.

For example, take the model probabilities [0.3, 0.2, 0.15, 0.15, 0.1, 0.1], priority [1,0,1,1,1,1], and threshold: 0.2.

Normally the predicted class would be class 0, but it is changed to class 1 according to the priority information.
(Class 1 is the more important class and 0.3 - 0.2 = 0.1 < 0.2)
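The worked example can be checked with a small single-sample sketch in plain Python. The helper name `pick_with_priority` is hypothetical and only for illustration; it assumes the same rule as the batched loop in this post, where the gap is measured against the currently selected class.

```python
def pick_with_priority(prob, priority, threshold):
    """Return the chosen class index for one sample (illustrative helper)."""
    # Visit classes from most to least probable.
    order = sorted(range(len(prob)), key=lambda c: prob[c], reverse=True)
    best = order[0]
    for c in order[1:]:
        # Stop once the gap to the current choice exceeds the threshold.
        if prob[best] - prob[c] > threshold:
            break
        # Switch only to a strictly more important (lower-numbered) priority.
        if priority[c] < priority[best]:
            best = c
    return best


# Example from the post: class 0 has the highest probability,
# but class 1 has the more important priority and is within the threshold.
print(pick_with_priority([0.3, 0.2, 0.15, 0.15, 0.1, 0.1], [1, 0, 1, 1, 1, 1], 0.2))  # 1
```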

I implemented this logic with the for loop below, but I think it can be optimized with matrix operations. How can I do it?

```
import torch

def modify_predResult_with_priority(prob, priority, threshold):
    # prob: [batch_size, num_cls]
    # priority: [num_cls], e.g. for 6 classes: [0,0,0,1,1,2]
    # threshold: e.g. 0.5

    # Sort each row so that column 0 holds the most probable class.
    sorted_prob, sorted_prob_cls_index = prob.sort(descending=True)
    sorted_prob_priority = torch.tensor(priority)[sorted_prob_cls_index]

    # clone() so the in-place updates below do not write into the sorted tensors.
    candidate_cls = sorted_prob_cls_index[:, 0].clone()
    candidate_prob = sorted_prob[:, 0].clone()
    candidate_priority = sorted_prob_priority[:, 0].clone()

    for batch in range(sorted_prob.shape[0]):
        # Start at 1: index 0 is the initial candidate itself.
        for idx in range(1, sorted_prob.shape[1]):
            # Stop once the gap to the current candidate exceeds the threshold.
            if (candidate_prob[batch] - sorted_prob[batch][idx]) > threshold:
                break
            # Switch to a class with a more important (lower) priority.
            if sorted_prob_priority[batch][idx] < candidate_priority[batch]:
                candidate_prob[batch] = sorted_prob[batch][idx]
                candidate_cls[batch] = sorted_prob_cls_index[batch][idx]
                candidate_priority[batch] = sorted_prob_priority[batch][idx]
    return candidate_cls
```

Maybe like this?

```
# Get the positions whose priority needs updating
idxs = prob > threshold

# Update the priority information
priority[idxs] += 1
```
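One possible loop-free sketch of the rule in the question is below. It rests on a simplifying assumption that differs from the original loop: the gap is always measured against the original top-1 probability, not against the running candidate, so results can differ when several switches would chain. The function name `modify_pred_with_priority_vectorized` is hypothetical. Among the classes within the threshold of the top probability, it picks the most important priority and breaks ties by the highest probability.

```python
import torch


def modify_pred_with_priority_vectorized(prob, priority, threshold):
    # prob: [batch_size, num_cls], priority: [num_cls] (lower = more important)
    priority = torch.as_tensor(priority, device=prob.device)

    # ASSUMPTION: eligibility is judged against the original top-1 probability.
    top_prob = prob.max(dim=1, keepdim=True).values
    eligible = (top_prob - prob) <= threshold          # [batch_size, num_cls]

    # Give ineligible classes a priority worse than any real one.
    worst = priority.max() + 1
    masked_priority = torch.where(eligible, priority.expand_as(prob), worst)

    # Keep only eligible classes that share the most important priority.
    best_priority = masked_priority.min(dim=1, keepdim=True).values
    winners = masked_priority == best_priority

    # Among those, take the most probable class.
    return prob.masked_fill(~winners, float("-inf")).argmax(dim=1)


# Example from the question: the prediction moves from class 0 to class 1.
prob = torch.tensor([[0.3, 0.2, 0.15, 0.15, 0.1, 0.1]])
print(modify_pred_with_priority_vectorized(prob, [1, 0, 1, 1, 1, 1], 0.2))
```

If the chained behaviour of the loop matters, the two versions should be compared on real data before swapping one for the other.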