PyTorch taking too much time to clear CUDA cache

Hi,

I am using the YOLO-POSE repository and trying to run batched inference on multiple images at the same time. The model takes roughly the same time for different batch sizes (up to 6), but the NMS step that follows takes two or sometimes three times as long as the model itself.

On further inspection I found that the first iteration of NMS alone accounts for about 90% of the total NMS time.

Here is the code:

import time

import torch
import torchvision

# xywh2xyxy and box_iou are helper functions from the repo's utils module

def non_max_suppression_kpt(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=True, multi_label=False,
                            labels=(), kpt_label=False, nc=None, nkpt=None):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """
    if nc is None:
        nc = prediction.shape[2] - 5  if not kpt_label else prediction.shape[2] - 56 # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0,6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        
        t1 = time.time()
        x = x[xc[xi]]  # confidence
        print("Time taken for NMS cam : " , xi , " x = x[xc[xi]] : " , time.time() - t1, " xc shape : " , xc.shape)
        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        t1 = time.time()
        x[:, 5:5+nc] *= x[:, 4:5]  # conf = obj_conf * cls_conf
        print("Time taken for NMS cam : " , xi , " x[:, 5:5+nc] *= x[:, 4:5] : " , time.time() - t1)

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        t1 = time.time()
        box = xywh2xyxy(x[:, :4])
        print("Time taken for NMS cam : " , xi , " box = xywh2xyxy(x[:, :4]) : " , time.time() - t1)

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            if not kpt_label:
                conf, j = x[:, 5:].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
            else:
                t1 = time.time()
                kpts = x[:, 6:]
                conf, j = x[:, 5:6].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres]
                print("Time taken for NMS cam : " , xi , " x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : " , time.time() - t1)


        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        t1 = time.time()
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        print("Time taken for NMS cam : " , xi , " c = x[:, 5:6] * (0 if agnostic else max_wh) : " , time.time() - t1)
        t1 = time.time()
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        print("Time taken for NMS cam : " , xi , " boxes, scores = x[:, :4] + c, x[:, 4] : " , time.time() - t1)
        t1 = time.time()
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        print("Time taken for NMS cam : " , xi , " i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS : " , time.time() - t1)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

and here are the times that were printed:

Time taken for NMS cam :  0  x = x[xc[xi]] :  0.024521827697753906  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  0  x[:, 5:5+nc] *= x[:, 4:5] :  6.413459777832031e-05
Time taken for NMS cam :  0  box = xywh2xyxy(x[:, :4]) :  0.0001697540283203125
Time taken for NMS cam :  0  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  0.00015687942504882812
Time taken for NMS cam :  0  c = x[:, 5:6] * (0 if agnostic else max_wh) :  2.09808349609375e-05
Time taken for NMS cam :  0  boxes, scores = x[:, :4] + c, x[:, 4] :  1.6927719116210938e-05
Time taken for NMS cam :  0  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00021004676818847656
Time taken for NMS cam :  1  x = x[xc[xi]] :  4.9114227294921875e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  1  x[:, 5:5+nc] *= x[:, 4:5] :  2.9802322387695312e-05
Time taken for NMS cam :  1  box = xywh2xyxy(x[:, :4]) :  0.00013899803161621094
Time taken for NMS cam :  1  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  0.00010418891906738281
Time taken for NMS cam :  1  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.7881393432617188e-05
Time taken for NMS cam :  1  boxes, scores = x[:, :4] + c, x[:, 4] :  1.9073486328125e-05
Time taken for NMS cam :  1  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011444091796875
Time taken for NMS cam :  2  x = x[xc[xi]] :  4.5299530029296875e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  2  x[:, 5:5+nc] *= x[:, 4:5] :  2.86102294921875e-05
Time taken for NMS cam :  2  box = xywh2xyxy(x[:, :4]) :  0.0001354217529296875
Time taken for NMS cam :  2  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  9.965896606445312e-05
Time taken for NMS cam :  2  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.811981201171875e-05
Time taken for NMS cam :  2  boxes, scores = x[:, :4] + c, x[:, 4] :  1.811981201171875e-05
Time taken for NMS cam :  2  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011396408081054688
Time taken for NMS cam :  3  x = x[xc[xi]] :  4.649162292480469e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  3  x[:, 5:5+nc] *= x[:, 4:5] :  2.86102294921875e-05
Time taken for NMS cam :  3  box = xywh2xyxy(x[:, :4]) :  0.0001380443572998047
Time taken for NMS cam :  3  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  0.00011682510375976562
Time taken for NMS cam :  3  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.8358230590820312e-05
Time taken for NMS cam :  3  boxes, scores = x[:, :4] + c, x[:, 4] :  1.8596649169921875e-05
Time taken for NMS cam :  3  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011110305786132812
Time taken for NMS cam :  4  x = x[xc[xi]] :  4.482269287109375e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  4  x[:, 5:5+nc] *= x[:, 4:5] :  2.86102294921875e-05
Time taken for NMS cam :  4  box = xywh2xyxy(x[:, :4]) :  0.00013637542724609375
Time taken for NMS cam :  4  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  9.822845458984375e-05
Time taken for NMS cam :  4  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.8358230590820312e-05
Time taken for NMS cam :  4  boxes, scores = x[:, :4] + c, x[:, 4] :  1.9311904907226562e-05
Time taken for NMS cam :  4  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011277198791503906
Time taken for NMS cam :  5  x = x[xc[xi]] :  4.458427429199219e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  5  x[:, 5:5+nc] *= x[:, 4:5] :  2.8133392333984375e-05
Time taken for NMS cam :  5  box = xywh2xyxy(x[:, :4]) :  0.00013685226440429688
Time taken for NMS cam :  5  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  9.775161743164062e-05
Time taken for NMS cam :  5  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.8358230590820312e-05
Time taken for NMS cam :  5  boxes, scores = x[:, :4] + c, x[:, 4] :  1.811981201171875e-05
Time taken for NMS cam :  5  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00010848045349121094

I realised that this might be due to CUDA trying to clear the cache for allocations in the first iteration, so I added torch.cuda.empty_cache() right after the model output, like so:

def forward(self, img, **kwargs):
    start = time.time()
    with torch.no_grad():
        if self.half:
            img = img.half().to(self.device)
        t1 = time.time()
        output = self.model(img, self.output_tensor)  # torch.Size([1, 45900, 57])
        print("Model Time Taken : ", time.time() - t1)
        torch.cuda.empty_cache()
        t2 = time.time()
        output = non_max_suppression_kpt(output,
                                         0.8,      # Confidence Threshold
                                         0.001,    # IoU Threshold
                                         nc=1,     # Number of Classes
                                         nkpt=17,  # Number of Keypoints
                                         kpt_label=True)
        print("NMS Time Taken : ", time.time() - t2)
        output, cameras = output_to_keypoint(output)
    output = output.reshape(-1, 17, 3)
    conf = output[..., -1]
    coords = output[..., :2]
    coords, conf = self.transform_keypoints(coords, conf)
    return coords, conf, cameras

This fixed the NMS time, but the bottleneck then shifted to torch.cuda.empty_cache().

I believe that this is what is causing the issue. When I disabled the caching allocator via PYTORCH_NO_CUDA_MEMORY_CACHING=1, NMS no longer took the extra time, but the model became the bottleneck instead, since every allocation now goes straight to the GPU rather than through the cache.

How can I fix this? Or how can I stop torch from clearing the CUDA cache during NMS?

@ptrblck I would be grateful for your help here.

CUDA operations are executed asynchronously, so you would need to synchronize the code via torch.cuda.synchronize() before starting and stopping the host timers. Otherwise the next blocking operation will accumulate the times of the already launched (and still running) kernels.
This also explains why the bottleneck “moves” to the next blocking operation (empty_cache() in your case).
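
For example, a minimal timing sketch of this pattern (model and img here stand in for your module and input batch):

import time
import torch

torch.cuda.synchronize()  # make sure previously launched kernels have finished
t0 = time.time()
out = model(img)
torch.cuda.synchronize()  # wait until the forward pass actually finished on the GPU
print("model time:", time.time() - t0)

# Alternatively, CUDA events measure the GPU execution time directly:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
out = model(img)
end.record()
torch.cuda.synchronize()  # elapsed_time is only valid once both events completed
print("model time (ms):", start.elapsed_time(end))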

Thanks for the quick response @ptrblck. That does make sense, as adding torch.cuda.synchronize() shifts the bottleneck to the synchronize call. Is there a way I can bypass this? Is batched inference the best way to process multiple images at once, or should I use torch multiprocessing (I assumed that is what torch uses in the backend to process the batch dimension)?

Would shifting to another format (TensorRT etc.) help? I would really appreciate any pointers on what to explore to overcome this problem.

No, since torch.cuda.synchronize() will just wait for the GPU until it finishes its workload.

Batched inputs will saturate the GPU better. PyTorch does not use multiprocessing in its backend, and modules accept batched inputs by default.
In case you are seeing a large CPU overhead in your profile for tiny inputs, you could use CUDA Graphs to reduce this overhead.

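As a rough sketch, the usual CUDA Graphs capture/replay pattern looks like this (model, the static shapes, and the warmup count are placeholders; input shapes must stay fixed between replays):

import torch

static_input = torch.randn(6, 3, 640, 640, device="cuda")  # fixed-shape input buffer

# Warm up on a side stream before capturing
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay: copy new data into the static buffer, then launch the whole graph at once
static_input.copy_(new_batch)  # new_batch: your next input, same shape/dtype
g.replay()
result = static_output.clone()
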
You could certainly try to export it to TRT and compare its performance against native PyTorch.
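
For reference, a common path is to export to ONNX first and then build a TensorRT engine from it; a minimal sketch (the file names, shapes, and trtexec flags are illustrative):

import torch

model.eval()  # model: your loaded YOLO-Pose module
dummy = torch.randn(6, 3, 640, 640, device="cuda")  # match your inference batch and resolution
torch.onnx.export(model, dummy, "yolopose.onnx",
                  opset_version=13,
                  input_names=["images"],
                  output_names=["output"])
# Then build and benchmark an engine, e.g.:
#   trtexec --onnx=yolopose.onnx --fp16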