Hi,
I am using the following repository YOLO-POSE, I am trying to do batched inference on multiple images at a same time. The model is taking similar time for different batch size (upto 6) but after the model computes the results and I need to perform NMS, the time taken was twice or sometime thrice of the model time.
On Further inspection I found out that only in tthe first iteration of NMS, it is taking 90% of the time of the NMS.
Here is the code : -
def non_max_suppression_kpt(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=True, multi_label=False,
labels=(), kpt_label=False, nc=None, nkpt=None):
"""Runs Non-Maximum Suppression (NMS) on inference results
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
if nc is None:
nc = prediction.shape[2] - 5 if not kpt_label else prediction.shape[2] - 56 # number of classes
xc = prediction[..., 4] > conf_thres # candidates
# Settings
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
t = time.time()
output = [torch.zeros((0,6), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
t1 = time.time()
x = x[xc[xi]] # confidence
print("Time taken for NMS cam : " , xi , " x = x[xc[xi]] : " , time.time() - t1, " xc shape : " , xc.shape)
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
l = labels[xi]
v = torch.zeros((len(l), nc + 5), device=x.device)
v[:, :4] = l[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
t1 = time.time()
x[:, 5:5+nc] *= x[:, 4:5] # conf = obj_conf * cls_conf
print("Time taken for NMS cam : " , xi , " x[:, 5:5+nc] *= x[:, 4:5] : " , time.time() - t1)
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
t1 = time.time()
box = xywh2xyxy(x[:, :4])
print("Time taken for NMS cam : " , xi , " box = xywh2xyxy(x[:, :4]) : " , time.time() - t1)
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else: # best class only
if not kpt_label:
conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
else:
t1 = time.time()
kpts = x[:, 6:]
conf, j = x[:, 5:6].max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres]
print("Time taken for NMS cam : " , xi , " x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : " , time.time() - t1)
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
elif n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
# Batched NMS
t1 = time.time()
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
print("Time taken for NMS cam : " , xi , " c = x[:, 5:6] * (0 if agnostic else max_wh) : " , time.time() - t1)
t1 = time.time()
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
print("Time taken for NMS cam : " , xi , " boxes, scores = x[:, :4] + c, x[:, 4] : " , time.time() - t1)
t1 = time.time()
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
print("Time taken for NMS cam : " , xi , " i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : " , time.time() - t1)
if i.shape[0] > max_det: # limit detections
i = i[:max_det]
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if (time.time() - t) > time_limit:
print(f'WARNING: NMS time limit {time_limit}s exceeded')
break # time limit exceeded
return output
and here is the time taken
Time taken for NMS cam : 0 x = x[xc[xi]] : 0.024521827697753906 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 0 x[:, 5:5+nc] *= x[:, 4:5] : 6.413459777832031e-05
Time taken for NMS cam : 0 box = xywh2xyxy(x[:, :4]) : 0.0001697540283203125
Time taken for NMS cam : 0 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 0.00015687942504882812
Time taken for NMS cam : 0 c = x[:, 5:6] * (0 if agnostic else max_wh) : 2.09808349609375e-05
Time taken for NMS cam : 0 boxes, scores = x[:, :4] + c, x[:, 4] : 1.6927719116210938e-05
Time taken for NMS cam : 0 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00021004676818847656
Time taken for NMS cam : 1 x = x[xc[xi]] : 4.9114227294921875e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 1 x[:, 5:5+nc] *= x[:, 4:5] : 2.9802322387695312e-05
Time taken for NMS cam : 1 box = xywh2xyxy(x[:, :4]) : 0.00013899803161621094
Time taken for NMS cam : 1 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 0.00010418891906738281
Time taken for NMS cam : 1 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.7881393432617188e-05
Time taken for NMS cam : 1 boxes, scores = x[:, :4] + c, x[:, 4] : 1.9073486328125e-05
Time taken for NMS cam : 1 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011444091796875
Time taken for NMS cam : 2 x = x[xc[xi]] : 4.5299530029296875e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 2 x[:, 5:5+nc] *= x[:, 4:5] : 2.86102294921875e-05
Time taken for NMS cam : 2 box = xywh2xyxy(x[:, :4]) : 0.0001354217529296875
Time taken for NMS cam : 2 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 9.965896606445312e-05
Time taken for NMS cam : 2 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.811981201171875e-05
Time taken for NMS cam : 2 boxes, scores = x[:, :4] + c, x[:, 4] : 1.811981201171875e-05
Time taken for NMS cam : 2 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011396408081054688
Time taken for NMS cam : 3 x = x[xc[xi]] : 4.649162292480469e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 3 x[:, 5:5+nc] *= x[:, 4:5] : 2.86102294921875e-05
Time taken for NMS cam : 3 box = xywh2xyxy(x[:, :4]) : 0.0001380443572998047
Time taken for NMS cam : 3 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 0.00011682510375976562
Time taken for NMS cam : 3 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.8358230590820312e-05
Time taken for NMS cam : 3 boxes, scores = x[:, :4] + c, x[:, 4] : 1.8596649169921875e-05
Time taken for NMS cam : 3 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011110305786132812
Time taken for NMS cam : 4 x = x[xc[xi]] : 4.482269287109375e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 4 x[:, 5:5+nc] *= x[:, 4:5] : 2.86102294921875e-05
Time taken for NMS cam : 4 box = xywh2xyxy(x[:, :4]) : 0.00013637542724609375
Time taken for NMS cam : 4 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 9.822845458984375e-05
Time taken for NMS cam : 4 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.8358230590820312e-05
Time taken for NMS cam : 4 boxes, scores = x[:, :4] + c, x[:, 4] : 1.9311904907226562e-05
Time taken for NMS cam : 4 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011277198791503906
Time taken for NMS cam : 5 x = x[xc[xi]] : 4.458427429199219e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 5 x[:, 5:5+nc] *= x[:, 4:5] : 2.8133392333984375e-05
Time taken for NMS cam : 5 box = xywh2xyxy(x[:, :4]) : 0.00013685226440429688
Time taken for NMS cam : 5 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 9.775161743164062e-05
Time taken for NMS cam : 5 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.8358230590820312e-05
Time taken for NMS cam : 5 boxes, scores = x[:, :4] + c, x[:, 4] : 1.811981201171875e-05
Time taken for NMS cam : 5 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00010848045349121094
I realised that it might be due to cuda trying to clear cache for allocation in the first iteration, So I added “torch.cuda.empty_cache()” right after the model output. Like so : -
def forward(self, img, **kwargs):
start = time.time()
with torch.no_grad():
if self.half:
img = img.half().to(self.device)
t1 = time.time()
output = self.model(img, self.output_tensor) # torch.Size([1, 45900, 57])-
print("Model Time Taken : ", time.time() - t1)
torch.cuda.empty_cache()
t2 = time.time()
output = non_max_suppression_kpt(output,
0.8, # Confidence Threshold
0.001, # IoU Threshold
nc=1, # Number of Classes
nkpt=17, # Number of Keypoints
kpt_label=True)
print("NMS Time Taken : ", time.time() - t2)
output, cameras = output_to_keypoint(output)
output = output.reshape(-1,17,3)
conf = output[..., -1]
coords = output[..., :2]
coords, conf = self.transform_keypoints(coords,conf)
return coords, conf, cameras
This fixed the nms time but now the bottleneck was shifted to “torch.cuda.empty_cache()”.
I believe that this is what is causing the issue. When I decided to skip cache using PYTORCH_NO_CUDA_MEMORY_CACHING = 1
, the NMS is not taking the time but now the model becomes the bottlneck as it is now using the memory of GPU directly.
I want to ask How I can fix this? or how I can stop torch from clearing cuda cache in NMS?