Non-max suppression using PyTorch

I am not sure if this has been answered before, but the FasterRCNN libraries perform non-max suppression with a CUDA kernel. I was hoping there is a way to code it in pure PyTorch, without any CUDA kernels. I cannot use the torchvision implementation, since I am going to work on 3D boxes while the ones from the torchvision library are for 2D boxes. The pure Python-based code I could find was:

import numpy as np


def nms(dets, scores, thresh):
    '''
    dets is a numpy array : num_dets, 6  (x1, y1, z1, x2, y2, z2)
    scores is a numpy array : num_dets,
    '''
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    z1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]
    z2 = dets[:, 5]

    volume = (x2 - x1 + 1) * (y2 - y1 + 1) * (z2 - z1 + 1)
    order = scores.argsort()[::-1]  # box indices sorted by descending score

    keep = []
    while order.size > 0:
        i = order[0]  # pick the highest-scoring remaining box
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        zz1 = np.maximum(z1[i], z1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        zz2 = np.minimum(z2[i], z2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)  # intersection width
        h = np.maximum(0.0, yy2 - yy1 + 1)  # intersection height
        l = np.maximum(0.0, zz2 - zz1 + 1)  # intersection length
        inter = w * h * l
        ovr = inter / (volume[i] + volume[order[1:]] - inter)  # IoU with the remaining boxes

        inds = np.where(ovr <= thresh)[0]  # boxes that do not overlap too much
        order = order[inds + 1]  # +1 because ovr was computed against order[1:]

    return keep

I would really appreciate it if there is some way to do this in pure PyTorch, and perhaps get rid of the for loop if possible.

While a custom CUDA kernel could potentially yield a speedup for these operations, you could start with a pure PyTorch implementation, which would also let you use the GPU.

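The numpy operations used here all have PyTorch counterparts (e.g. np.maximum/np.minimum → torch.max/torch.min, np.where → torch.where, argsort → torch.argsort), so you could just rewrite the code.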
If you see a bottleneck in this translation from numpy to PyTorch, you would need to dig through the torchvision kernel and try to adapt it to your use case.
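As a rough sketch of what such a rewrite can look like while also removing most of the per-iteration work: compute the full N×N IoU matrix in one vectorized step via broadcasting, then run the greedy pass over that matrix. The greedy pass still loops, since each keep/suppress decision depends on earlier ones, but each iteration is just cheap indexing. This assumes axis-aligned boxes in (x1, y1, z1, x2, y2, z2) format with the same "+1" convention as the numpy code; `pairwise_iou_3d` and `nms_3d` are hypothetical names, and memory grows as O(N²), so it suits moderate N.

```python
import torch


def pairwise_iou_3d(boxes: torch.Tensor) -> torch.Tensor:
    """IoU between every pair of axis-aligned 3D boxes.

    boxes: [N, 6] tensor of (x1, y1, z1, x2, y2, z2); returns [N, N].
    Uses the same "+1" (inclusive coordinate) convention as the numpy code.
    """
    lo = boxes[:, :3]  # [N, 3] minimum corners
    hi = boxes[:, 3:]  # [N, 3] maximum corners
    volume = (hi - lo + 1).prod(dim=1)  # [N]

    # Broadcast [N, 1, 3] against [1, N, 3] to get all pairwise intersections
    inter_lo = torch.max(lo[:, None, :], lo[None, :, :])
    inter_hi = torch.min(hi[:, None, :], hi[None, :, :])
    inter = (inter_hi - inter_lo + 1).clamp(min=0).prod(dim=2)  # [N, N]

    return inter / (volume[:, None] + volume[None, :] - inter)


def nms_3d(boxes: torch.Tensor, scores: torch.Tensor, thresh: float) -> torch.Tensor:
    """Greedy NMS; returns indices of the kept boxes, highest score first."""
    order = scores.argsort(descending=True)
    iou = pairwise_iou_3d(boxes[order])  # rows/cols follow the sorted order

    n = order.numel()
    suppressed = torch.zeros(n, dtype=torch.bool, device=boxes.device)
    keep = []
    for i in range(n):
        if suppressed[i]:
            continue
        keep.append(i)
        # Mark everything overlapping box i; entries before i are already
        # decided, so this only affects the later (lower-scored) boxes.
        suppressed |= iou[i] > thresh
    return order[torch.tensor(keep, dtype=torch.long, device=boxes.device)]
```

Because it only uses tensor operations, the same code runs on the GPU once `boxes` and `scores` are moved there, although the `if suppressed[i]` check forces a host/device sync on every iteration.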

Also, double post from here.

@ptrblck I have tried implementing the same function. Since I have your attention here, it would be nice if you could take a look and let me know whether I have translated the code correctly. I will also remove the second post, since a reply here resolves it automatically.

import torch


def nms(dets, thresh):
    '''
    dets is a torch tensor : num_dets, 6  (x1, y1, z1, x2, y2, z2)
    The detections are already sorted by descending score and so can be used directly.
    '''
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    z1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]
    z2 = dets[:, 5]

    volume = (x2 - x1 + 1) * (y2 - y1 + 1) * (z2 - z1 + 1)
    order = torch.arange(dets.size(0), device=dets.device)  # the boxes are already sorted

    keep = []
    while order.size(0) > 0:
        i = order[0]  # pick the highest-scoring remaining box
        keep.append(i)
        xx1 = torch.max(x1[i], x1[order[1:]])
        yy1 = torch.max(y1[i], y1[order[1:]])
        zz1 = torch.max(z1[i], z1[order[1:]])
        # the far corner of the intersection is the element-wise minimum (np.minimum)
        xx2 = torch.min(x2[i], x2[order[1:]])
        yy2 = torch.min(y2[i], y2[order[1:]])
        zz2 = torch.min(z2[i], z2[order[1:]])

        w = (xx2 - xx1 + 1).clamp(min=0)  # intersection width
        h = (yy2 - yy1 + 1).clamp(min=0)  # intersection height
        l = (zz2 - zz1 + 1).clamp(min=0)  # intersection length
        inter = w * h * l
        ovr = inter / (volume[i] + volume[order[1:]] - inter)

        inds = torch.nonzero(ovr <= thresh).squeeze(1)  # [K, 1] -> [K]
        # ovr was computed against order[1:], so an offset of 1 has to be added.
        # Only the boxes whose overlap is at or below the threshold survive to the next round.
        order = order[inds + 1]

    return keep

The torch.nonzero() operation should probably be replaced with torch.where, as you are using np.where in your original code.
Besides that it looks alright, but I would recommend verifying it with some dummy tensors.
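The difference between the two is the output shape: `np.where(mask)[0]` and `torch.where(mask)[0]` give a 1-D tensor of K indices, while `torch.nonzero(mask)` gives a [K, 1] tensor. Fed straight into `order[inds + 1]`, the [K, 1] version turns `order` into a 2-D tensor, and the indexing goes wrong on later iterations. A small check with made-up overlap values:

```python
import torch

ovr = torch.tensor([0.1, 0.7, 0.3, 0.9])  # made-up IoUs against order[1:]
mask = ovr <= 0.5
order = torch.arange(5)

# torch.nonzero keeps a trailing dimension
print(torch.nonzero(mask).shape)             # torch.Size([2, 1])
print(order[torch.nonzero(mask) + 1].shape)  # torch.Size([2, 1]), now 2-D

# torch.where(mask) returns a tuple of 1-D index tensors, like np.where
print(torch.where(mask)[0])                  # tensor([0, 2])
print(order[torch.where(mask)[0] + 1])       # tensor([1, 3])

# torch.nonzero(..., as_tuple=True) gives the same tuple as torch.where(mask)
print(torch.nonzero(mask, as_tuple=True)[0])  # tensor([0, 2])
```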

Just create random tensors in the expected shape and compare the outputs of the new PyTorch method with those of the original numpy method (by passing the tensors into the numpy version as tensor.numpy()).
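A minimal sketch of such a check, assuming axis-aligned boxes in (x1, y1, z1, x2, y2, z2) format. `nms_numpy` is a condensed copy of the numpy code from the question, `nms_torch` is a straight torch translation of it, and `random_boxes` is a hypothetical helper that guarantees x2 ≥ x1 etc. Scores come from a random permutation so there are no ties, which would otherwise let the two sorts legitimately disagree.

```python
import numpy as np
import torch


def nms_numpy(dets, scores, thresh):
    """Reference: the numpy NMS from the question, condensed (same math)."""
    lo, hi = dets[:, :3], dets[:, 3:]
    volume = np.prod(hi - lo + 1, axis=1)
    order = scores.argsort()[::-1]  # descending score
    keep = []
    while order.size > 0:
        i, rest = order[0], order[1:]
        keep.append(int(i))
        size = np.maximum(0.0, np.minimum(hi[i], hi[rest]) - np.maximum(lo[i], lo[rest]) + 1)
        inter = np.prod(size, axis=1)
        ovr = inter / (volume[i] + volume[rest] - inter)
        order = rest[ovr <= thresh]  # same as order[np.where(ovr <= thresh)[0] + 1]
    return keep


def nms_torch(dets, scores, thresh):
    """The same algorithm written with torch ops only (also runs on the GPU)."""
    lo, hi = dets[:, :3], dets[:, 3:]
    volume = (hi - lo + 1).prod(dim=1)
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i, rest = order[0], order[1:]
        keep.append(int(i))
        size = (torch.minimum(hi[i], hi[rest]) - torch.maximum(lo[i], lo[rest]) + 1).clamp(min=0)
        inter = size.prod(dim=1)
        ovr = inter / (volume[i] + volume[rest] - inter)
        order = rest[ovr <= thresh]
    return keep


def random_boxes(n, generator):
    """Random axis-aligned boxes with x2 >= x1, y2 >= y1, z2 >= z1."""
    corner = torch.rand(n, 3, generator=generator, dtype=torch.float64) * 50
    size = torch.rand(n, 3, generator=generator, dtype=torch.float64) * 20 + 1
    return torch.cat([corner, corner + size], dim=1)


g = torch.Generator().manual_seed(0)
dets = random_boxes(200, g)
scores = torch.randperm(200, generator=g).to(torch.float64) / 200  # distinct scores, no ties

ref = nms_numpy(dets.numpy(), scores.numpy(), 0.3)
out = nms_torch(dets, scores, 0.3)
print(ref == out, len(ref))
```

Indexing with the boolean mask `rest[ovr <= thresh]` sidesteps the `nonzero`/`where` shape question entirely, since it selects the surviving indices directly.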

@ptrblck Thank you so much for your feedback. I managed to reduce the function size further.

import torch


def nms_pytorch(dets, thresh):
    '''
    dets is a torch tensor : num_dets, 6
    The detections are already sorted by descending score and so can be used directly.
    bbox_overlap (not shown here) is assumed to return the M x N IoU matrix between two sets of boxes.
    '''

    keep = []
    selected = torch.ones(dets.size(0), dtype=torch.bool, device=dets.device)  # True = not suppressed yet
    for i, bbox in enumerate(dets):
        if not selected[i]:
            continue
        keep.append(i)
        ovr = bbox_overlap(bbox.unsqueeze(0), dets[i+1:])
        # ovr is a 1 x N tensor and we need the columns where it exceeds thresh
        inds = torch.where(ovr > thresh)[1]
        # ovr was computed against dets[i+1:], so an offset of i + 1 has to be added.
        # The boxes that overlap too much are marked as suppressed.
        selected[inds + i + 1] = False

    return keep
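The function above relies on a `bbox_overlap` helper that isn't shown in the thread. Judging from the "1 X N" comment and the `torch.where(...)[1]` indexing, it returns an [M, N] IoU matrix. What follows is a guess at such a helper, not the poster's code: a minimal sketch assuming axis-aligned (x1, y1, z1, x2, y2, z2) boxes and the same "+1" convention as the earlier numpy code.

```python
import torch


def bbox_overlap(boxes_a: torch.Tensor, boxes_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: IoU between two sets of axis-aligned 3D boxes.

    boxes_a: [M, 6], boxes_b: [N, 6], both (x1, y1, z1, x2, y2, z2).
    Returns an [M, N] IoU matrix, using the "+1" convention of the numpy code.
    """
    lo_a, hi_a = boxes_a[:, None, :3], boxes_a[:, None, 3:]  # [M, 1, 3]
    lo_b, hi_b = boxes_b[None, :, :3], boxes_b[None, :, 3:]  # [1, N, 3]

    vol_a = (hi_a - lo_a + 1).prod(dim=2)  # [M, 1]
    vol_b = (hi_b - lo_b + 1).prod(dim=2)  # [1, N]

    # Per-axis intersection extent, clamped at 0 for disjoint boxes
    extent = (torch.minimum(hi_a, hi_b) - torch.maximum(lo_a, lo_b) + 1).clamp(min=0)
    inter = extent.prod(dim=2)  # [M, N]
    return inter / (vol_a + vol_b - inter)
```

For the last box, `dets[i+1:]` is empty and the helper simply returns a [1, 0] tensor, so the loop needs no special case.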

Is there a way to compare its performance against the GPU (CUDA kernel) version?

Thanks again for your time

You could run a quick benchmark using this pseudo-code:

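import time

import torch

nb_iters = 100

# warmup (the first calls pay one-time costs such as CUDA context creation)
for _ in range(10):
    my_fun()

# CUDA calls are asynchronous, so synchronize before reading the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(nb_iters):
    my_fun()
torch.cuda.synchronize()
t1 = time.perf_counter()
print((t1 - t0) / nb_iters)  # average seconds per call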

where my_fun would be your current function(s) and the one(s) you would like to compare it with.
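As an alternative to hand-rolled timing, PyTorch ships `torch.utils.benchmark.Timer`, which according to its documentation performs warmup runs and synchronizes CUDA when needed. A sketch under those assumptions; `pairwise_iou` is only a placeholder workload standing in for whichever NMS variant you want to measure, and the box sizes are made up.

```python
import torch
import torch.utils.benchmark as benchmark


def pairwise_iou(boxes):
    # Placeholder workload: all-pairs IoU of [N, 6] axis-aligned 3D boxes
    lo, hi = boxes[:, None, :3], boxes[:, None, 3:]  # [N, 1, 3]
    extent = torch.minimum(hi, hi.transpose(0, 1)) - torch.maximum(lo, lo.transpose(0, 1)) + 1
    inter = extent.clamp(min=0).prod(dim=2)  # [N, N]
    vol = (hi - lo + 1).prod(dim=2)  # [N, 1]
    return inter / (vol + vol.T - inter)


device = "cuda" if torch.cuda.is_available() else "cpu"
corner = torch.rand(500, 3, device=device) * 50
boxes = torch.cat([corner, corner + torch.rand(500, 3, device=device) * 20 + 1], dim=1)

timer = benchmark.Timer(
    stmt="pairwise_iou(boxes)",
    globals={"pairwise_iou": pairwise_iou, "boxes": boxes},
)
measurement = timer.timeit(100)  # run the statement 100 times
print(f"{device}: {measurement.mean * 1e3:.3f} ms per call")
```

Creating one `Timer` per candidate function with the same inputs gives directly comparable numbers.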
