Non-max Suppression for boxes with multiple classes

Hi. How do I do non-max suppression for bounding boxes with multiple labels in PyTorch?
I was hoping for something that can provide the same functionality as this TensorFlow function : tf.image.combined_non_max_suppression

Thank You.

Hi @benihime91, have you checked torchvision.ops.nms?
https://pytorch.org/docs/stable/torchvision/ops.html

Hello @Usama_Hasan
Yes I have …
torchvision.ops.nms performs NMS only for the given boxes & scores; I don’t think it takes the classes of the boxes into account (correct me if I am wrong)…

Yes that’s the case.
I guess this will help you in what you’re looking for.

@Usama_Hasan

I was looking through the source code of Detectron2, where I found this really simple & efficient implementation.
I implemented it like so:

def batched_nms(boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float):
    """
    Per-group non-maximum suppression: boxes only compete with boxes that
    share the same value in `idxs`.

    Adapted from detectron2:
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/roi_heads/fast_rcnn.py#L107-L115

    Below 40000 boxes the call is forwarded unchanged to
    torchvision.ops.boxes.batched_nms. At or above that count, plain
    torchvision.ops.nms is run once per distinct value of `idxs` and the
    surviving indices are merged.

    Args:
        boxes: tensor whose last dimension has size 4 (asserted).
        scores: one score per box.
        idxs: one group/class id per box.
        iou_threshold: IoU threshold handed to the underlying NMS call.

    Returns:
        Indices into `boxes` of the kept boxes. On the per-group path they
        are ordered by descending score.
    """
    assert boxes.shape[-1] == 4

    if len(boxes) >= 40000:
        # Large input: suppress each group separately and record the
        # survivors in a boolean mask over all boxes.
        keep_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
        group_ids = torch.jit.annotate(List[int], torch.unique(idxs).cpu().tolist())
        for group_id in group_ids:
            members = (idxs == group_id).nonzero().view(-1)
            survivors = torchvision.ops.nms(boxes[members], scores[members], iou_threshold)
            # `survivors` indexes into `members`; map back to global indices.
            keep_mask[members[survivors]] = True
        kept = keep_mask.nonzero().view(-1)
        order = scores[kept].argsort(descending=True)
        return kept[order]

    return torchvision.ops.boxes.batched_nms(boxes, scores, idxs, iou_threshold)
def NonMaxSupression(predictions, score_thres=0.5, iou_thres=0.5):
    """
    Filter a batch of raw predictions by confidence and apply per-class NMS.

    For every element of `predictions` (iterated along the first dimension),
    rows whose column 4 is below `score_thres` are dropped. The final score
    of a row is column 4 multiplied by the largest value in columns 5 and
    up, and the position of that largest value is used as the class label.
    NOTE(review): column 4 is presumably an objectness score and columns 5+
    per-class scores - confirm against the model that produces `predictions`.

    The first four columns are passed through `cvt_boxes` before use.
    NOTE(review): presumably this converts boxes to the (x1, y1, x2, y2)
    format expected by torchvision NMS - confirm against its definition.

    The input tensor is not modified; the conversion is done on a copy.

    Args:
        predictions: tensor indexed as [image, row, column].
        score_thres: minimum value of column 4 for a row to be kept.
        iou_thres: IoU threshold forwarded to `batched_nms`.

    Returns:
        A list of tensors, one per image that had at least one row above
        the threshold. Each row is [box (4 values), score, class label].
        Images with no surviving rows are skipped, so the list can be
        shorter than the batch and its positions do not necessarily match
        batch indices.

    Raises:
        AssertionError: if no image in the batch produced any detection.
    """
    # Work on a copy so the caller's tensor is not overwritten by the
    # box conversion (calling this twice must not convert twice).
    predictions = predictions.clone()
    predictions[..., :4] = cvt_boxes(predictions[..., :4])
    output_preds = []

    for pred in predictions:
        pred = pred[pred[:, 4] >= score_thres]
        if not pred.size(0):
            continue

        boxes = pred[:, :4]
        # Best class score and its index, computed once for both uses.
        class_conf, cls_labels = pred[:, 5:].max(1)
        scores = pred[:, 4] * class_conf

        detections = torch.cat(
            (boxes, scores.unsqueeze(dim=1), cls_labels.unsqueeze(dim=1).float()),
            dim=-1,
        )

        scores = scores.view([-1])
        cls_labels = cls_labels.type(torch.int32).view([-1])

        assert boxes.size(0) == scores.size(0) == cls_labels.size(0)

        keep = batched_nms(boxes, scores, cls_labels, iou_thres)
        output_preds.append(detections[keep])

    assert len(output_preds) > 0, "[INFO] No objects detected in current Image"

    return output_preds

I think this is fine; it's just not really a parallel implementation.