Hi. How do I do non-max suppression for bounding boxes with multiple labels in PyTorch?
I was hoping for something that provides the same functionality as this TensorFlow function: tf.image.combined_non_max_suppression
Thank you.
Hi @benihime91, have you checked torchvision.ops.nms?
https://pytorch.org/docs/stable/torchvision/ops.html
Hello @Usama_Hasan
Yes, I have. torchvision.ops.nms performs NMS only on the given boxes and scores. I don’t think it takes the classes of the boxes into account (correct me if I am wrong).
Yes that’s the case.
I guess this will help you in what you’re looking for.
I was looking around the source code of Detectron2, where I found this really simple and efficient implementation.
I implemented it like so:
from typing import List

import torch
import torchvision


def batched_nms(boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float):
    """
    From:
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/roi_heads/fast_rcnn.py#L107-L115

    Same as torchvision.ops.boxes.batched_nms, but safer.
    """
    assert boxes.shape[-1] == 4
    # Small inputs: defer to torchvision's own per-class NMS.
    if len(boxes) < 40000:
        return torchvision.ops.boxes.batched_nms(boxes, scores, idxs, iou_threshold)
    # Large inputs: run plain NMS separately for each class id.
    result_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
    for id in torch.jit.annotate(List[int], torch.unique(idxs).cpu().tolist()):
        mask = (idxs == id).nonzero().view(-1)  # indices of boxes with this class
        keep = torchvision.ops.nms(boxes[mask], scores[mask], iou_threshold)
        result_mask[mask[keep]] = True
    # Return the kept indices sorted by decreasing score.
    keep = result_mask.nonzero().view(-1)
    keep = keep[scores[keep].argsort(descending=True)]
    return keep
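def NonMaxSupression(predictions, score_thres=0.5, iou_thres=0.5):
    # As the indexing below implies, each row of `predictions` is laid out as
    # [4 box coords, objectness score, per-class scores...].
    # cvt_boxes is my own helper (not shown here); torchvision's NMS expects
    # boxes as (x1, y1, x2, y2). Note that this overwrites `predictions` in place.
    predictions[..., :4] = cvt_boxes(predictions[..., :4])
    output_preds = []
    for pred in predictions:
        # Drop boxes whose objectness score is below the threshold.
        pred = pred[pred[:, 4] >= score_thres]
        if not pred.size(0):
            continue  # images with no boxes left are skipped in the output list
        boxes = pred[:, :4]
        # Final score = objectness * best class score; label = index of that class.
        cls_scores, cls_labels = pred[:, 5:].max(1)
        scores = pred[:, 4] * cls_scores
        detections = torch.cat((boxes, scores.unsqueeze(dim=1), cls_labels.unsqueeze(dim=1).float()),
                               dim=-1)
        scores, cls_labels = scores.view([-1]), cls_labels.type(torch.int32).view([-1])
        assert boxes.size(0) == scores.size(0) == cls_labels.size(0)
        # Per-class NMS; each output row is [x1, y1, x2, y2, score, label].
        keep = batched_nms(boxes, scores, cls_labels, iou_thres)
        output_preds.append(detections[keep])
    assert len(output_preds) > 0, "[INFO] No objects detected in current Image"
    return output_preds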
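The cvt_boxes helper called above is not shown in the post. As a minimal sketch only, assuming the raw predictions store boxes as (center x, center y, width, height), which is an assumption and not something the post states, it could look like this:

```python
import torch


def cvt_boxes(boxes: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the helper used above.

    Assumes the last dimension is (cx, cy, w, h) and converts it to the
    (x1, y1, x2, y2) corner format that torchvision's NMS functions expect.
    """
    cx, cy, w, h = boxes.unbind(dim=-1)
    return torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=-1)


# A 10x10 box centered at (5, 5) becomes the corner box (0, 0, 10, 10).
example = cvt_boxes(torch.tensor([[5.0, 5.0, 10.0, 10.0]]))
print(example.tolist())  # [[0.0, 0.0, 10.0, 10.0]]
```

If the model already outputs corner-format boxes, the conversion step can simply be dropped.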
I think this is fine; it just isn't a really parallel implementation.