Batch non-maximum suppression on the GPU

Hello there,
Using the awesome idea behind torchvision's batched_nms, the code below decodes detections for several images and several classes at once. It works because batched_nms offsets the boxes according to their category (here, every image/class pair is its own category), so boxes from different categories can never overlap and you never perform a wrong suppression.

I also tried to speed up box encoding; if you are interested, you can have a peek here: https://github.com/etienne87/torch_object_rnn/blob/master/core/anchors.py

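```python
import torch
from torchvision.ops import batched_nms

# box_preds: (N, A, 4) boxes per image and anchor
# cls_preds: (N, A, label_offset + C) scores; the first label_offset columns are skipped
num_classes = cls_preds.shape[-1] - self.label_offset
num_anchors = box_preds.shape[1]
# Repeat every box once per class -> (N, A, C, 4), then flatten to (N*A*C, 4)
boxes = box_preds.unsqueeze(2).expand(-1, num_anchors, num_classes, 4).contiguous()
scores = cls_preds[..., self.label_offset:].contiguous()
boxes = boxes.view(-1, 4)
scores = scores.view(-1)
# Group id per (image, class) pair: image * num_classes + class
rows = torch.arange(len(box_preds), dtype=torch.long, device=box_preds.device)[:, None]
cols = torch.arange(num_classes, dtype=torch.long, device=box_preds.device)[None, :]
idxs = rows * num_classes + cols
idxs = idxs.unsqueeze(1).expand(len(box_preds), num_anchors, num_classes)
# Keep idxs as int64 (not scores' float dtype) so labels and batch_index stay integers;
# reshape copies the expanded view, which view(-1) cannot flatten
idxs = idxs.reshape(-1)
# Drop low-scoring candidates before NMS (boolean indexing returns contiguous tensors)
mask = scores >= score_thresh
boxesf = boxes[mask]
scoresf = scores[mask]
idxsf = idxs[mask]

# One call: suppression only happens within each (image, class) group
keep = batched_nms(boxesf, scoresf, idxsf, nms_thresh)

boxes = boxesf[keep]
scores = scoresf[keep]
labels = idxsf[keep] % num_classes  # class index, counted after the skipped label_offset columns
batch_index = idxsf[keep] // num_classes  # image each detection belongs to
```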