How can I speed up inference with an ensemble technique?

I have two models that individually give decent results, but ensembling them requires a lot of compute in post-processing and the overall inference time is huge. Both models are heavy, so some of this is expected, but I suspect I am missing something in the code below.

import gc

import torch

# thing_classes, ACCEPTANCE_THRESHOLD, nms_predictions, ensemble_pred_masks
# and rle_encode are defined earlier in the notebook.

def ensemble(data, models):
    with torch.no_grad():
        encodings = {}
        cache = {}
        # One cache entry per input image, with per-category buckets for the
        # pooled detections of all models.
        for i in range(len(data)):
            cache[i] = {
                'classes' : [[] for _ in range(len(thing_classes))],
                'scores' : [[] for _ in range(len(thing_classes))],
                'bboxes' : [[] for _ in range(len(thing_classes))],
                'masks' : [[] for _ in range(len(thing_classes))]
            }
            encodings[i] = [[] for _ in range(len(thing_classes))]

        # Stage 1: run every model on the batch and pool its detections per
        # image and per category.
        for model in models:
            outputs = model(data)

            for idx, output in enumerate(outputs):
                output = output['instances']
                take = output.scores.cpu().numpy() >= ACCEPTANCE_THRESHOLD
                pred_classes = output.pred_classes.cpu().numpy()[take]

                for cat in range(len(thing_classes)):
                    cache[idx]['classes'][cat].extend(output.pred_classes[take][pred_classes == cat].cpu().numpy().tolist())
                    cache[idx]['scores'][cat].extend(output.scores[take][pred_classes == cat].cpu().numpy().tolist())
                    cache[idx]['bboxes'][cat].extend(output.pred_boxes[take][pred_classes == cat].tensor.cpu().numpy().tolist())
                    cache[idx]['masks'][cat].extend(output.pred_masks[take][pred_classes == cat].cpu().numpy())

        # Stage 2: NMS across the pooled detections, merge the surviving
        # masks, then RLE-encode the result.
        for item in cache:
            for cat in range(len(thing_classes)):
                classes = cache[item]['classes'][cat]
                scores = cache[item]['scores'][cat]
                bboxes = cache[item]['bboxes'][cat]
                masks = cache[item]['masks'][cat]
                assert len(classes) == len(masks), 'ensemble length mismatch'
                if len(classes) > 1:
                    try:
                        classes, scores, masks = nms_predictions(
                            classes,
                            scores,
                            bboxes,
                            masks,
                            shape=(masks[0].shape[0], masks[0].shape[1])
                        )
                        encoded_masks = ensemble_pred_masks(
                            masks,
                            classes,
                            min_pixels=[75, 75, 75, 75],
                            shape=(masks[0].shape[0], masks[0].shape[1])
                        )
                        encodings[item][cat].append(rle_encode(encoded_masks))
                    except Exception as e:
                        print("Error", len(masks), e)
                        encodings[item][cat].append(" ")
                else:
                    encodings[item][cat].append(" ")

                del masks
                del scores
                gc.collect()

        del cache
        gc.collect()
    return encodings
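
One stage I already suspect is the per-category filtering in stage 1: every category slice triggers four separate .cpu().numpy() calls, so each output gets copied from the GPU many times per model. Here is a sketch (untested) of what I think the inner part of that loop could collapse to, moving each Instances object to the CPU once and doing all the slicing in NumPy afterwards (assuming these are Detectron2-style Instances objects, as above):

    # Untested sketch of the stage 1 inner loop: one device-to-host transfer
    # per output instead of roughly 4 * len(thing_classes) transfers.
    inst = output['instances'].to('cpu')
    scores = inst.scores.numpy()
    take = scores >= ACCEPTANCE_THRESHOLD
    classes = inst.pred_classes.numpy()[take]
    scores = scores[take]
    bboxes = inst.pred_boxes.tensor.numpy()[take]
    masks = inst.pred_masks.numpy()[take]

    for cat in range(len(thing_classes)):
        sel = classes == cat
        cache[idx]['classes'][cat].extend(classes[sel].tolist())
        cache[idx]['scores'][cat].extend(scores[sel].tolist())
        cache[idx]['bboxes'][cat].extend(bboxes[sel].tolist())
        cache[idx]['masks'][cat].extend(masks[sel])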

Beyond that, can I speed up the inference by changing any of the other stages? I am also wondering whether the gc.collect() calls inside the inner loop are hurting more than helping.
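
Another direction I have been meaning to try, assuming both models sit on a CUDA device and are numerically stable in half precision, is running the forward passes under mixed precision. A minimal sketch (untested):

    import torch

    # Untested sketch: run both forward passes under fp16 autocast. The rest
    # of the pipeline stays unchanged; the outputs are still ordinary tensors.
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
        all_outputs = [model(data) for model in models]

Would either of these changes make a meaningful difference, or is the bottleneck somewhere else entirely?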