Dilation before IoU

I am trying to do the following:

  1. Do one iteration of dilation on the output of the model (in validation or testing).
  2. Compute the IoU after the dilation for each class.

The reason for the dilation step is that I'm willing to tolerate some error for the second class. For example, if the model outputs a thin line where the ground truth is a bit thicker, the IoU normally suffers from that, but for my purposes a thin line is acceptable. However, for the first class (background) and the third class, I don't want to apply any dilation.
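To make the intent concrete, here is a minimal NumPy-only sketch of what "dilate only the second class" means. The function name, the shifted-view implementation, and the square kernel size are my choices for illustration, not part of my actual pipeline:

```python
import numpy as np

def dilate_class(pred_map, class_idx, ksize=5):
    """Grow only the pixels of `class_idx` in a 2-D map of class indices.

    Equivalent to a binary dilation with a ksize x ksize square kernel,
    implemented with shifted views so it needs nothing beyond NumPy.
    Pixels inside the grown region are relabelled as `class_idx`.
    """
    h, w = pred_map.shape
    pad = ksize // 2
    mask = np.pad(pred_map == class_idx, pad)
    grown = np.zeros((h, w), dtype=bool)
    for dy in range(ksize):
        for dx in range(ksize):
            grown |= mask[dy:dy + h, dx:dx + w]
    out = pred_map.copy()
    out[grown] = class_idx
    return out
```

Classes 0 and 2 would simply be compared against the ground truth as-is, with only the thin-line class run through this step first.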

Currently I dilate the whole output, as shown below.

This code, without the dilation, works great:

import numpy as np
import torch

def iou(output, label, n_classes):
    # output: (B, C, H, W) logits, label: (B, H, W) class indices
    ious = []
    pred = torch.argmax(output, dim=1).squeeze(1).view(-1)
    target = label.view(-1)
    SMOOTH = 1e-6  # avoids 0/0 when a class is absent from both tensors
    for idx in range(n_classes):
        preds_inds = pred == idx
        target_inds = target == idx
        intersection = (preds_inds & target_inds).float().sum()
        union = (preds_inds | target_inds).float().sum()
        iou = (intersection + SMOOTH) / (union + SMOOTH)
        ious.append(iou.item())
    return np.array(ious)

Adding the dilation step:

import cv2

def iou(output, label, n_classes):
    ious = []
    get_cuda_device = "cuda:0"
    cuda_check = output.is_cuda
    if cuda_check:
        get_cuda_device = output.get_device()
    kernel = np.ones((5, 5), np.uint8)
    pred = torch.argmax(output, dim=1).squeeze(1).view(-1).detach().cpu().numpy() * 255
    pred = cv2.dilate(pred.astype(np.uint8), kernel, iterations=1) / 255
    pred = torch.from_numpy(pred).long().to(get_cuda_device)
    target = label.view(-1)
    SMOOTH = 1e-6
    for idx in range(n_classes):
        preds_inds = pred == idx
        target_inds = target == idx
        intersection = (preds_inds & target_inds).float().sum()
        union = (preds_inds | target_inds).float().sum()
        iou = (intersection + SMOOTH) / (union + SMOOTH)
        ious.append(iou.item())
    return np.array(ious)

Now I get the following error:

intersection = (preds_inds & target_inds).float().sum()
RuntimeError: CUDA out of memory. Tried to allocate 7440.75 GiB (GPU 0; 10.76 GiB total capacity; 463.73 MiB already allocated; 9.33 GiB free; 584.00 MiB reserved in total by PyTorch)
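For what it's worth, the absurd allocation size fits a shape mismatch rather than a genuinely huge tensor: judging by the `squeeze(1)` in the fix below, the cv2 round trip hands the prediction back as a 2-D `(N, 1)` array while `target_inds` stays `(N,)`, so the elementwise `&` broadcasts them to an `N x N` matrix (roughly 2.8 million pixels squared, about 8e12 booleans, which matches the ~7440 GiB request). A tiny NumPy demonstration of the same broadcasting rule, which PyTorch follows too:

```python
import numpy as np

# preds_inds came back as (N, 1) from the cv2 round trip,
# while target_inds stayed (N,): the elementwise AND broadcasts to (N, N).
preds_inds = np.zeros((5, 1), dtype=bool)
target_inds = np.zeros(5, dtype=bool)
print((preds_inds & target_inds).shape)  # (5, 5) -- quadratic blow-up
```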

Can someone suggest a fix, or even a better way to achieve what I described above (not necessarily using dilation)?

EDIT: Fixed the problem by replacing

pred = torch.from_numpy(pred).long().to(get_cuda_device)

with

pred = torch.from_numpy(pred).long().squeeze(1).view(-1).to(get_cuda_device)

My question still holds, though: is there a better way to do this?
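One torch-only alternative I'm considering (a sketch under my shape assumptions: `output` is `(B, C, H, W)` logits, `label` is `(B, H, W)` indices; `dilate_classes` and `iou_with_dilation` are names I made up): dilate only the classes you want, on the GPU, using `max_pool2d` with stride 1 as a square-kernel binary dilation. This avoids the CPU round trip entirely, and also sidesteps the `* 255` trick, which as far as I can tell would wrap class index 2 to 254 in uint8 (2 * 255 = 510) and corrupt it for a three-class problem anyway:

```python
import numpy as np
import torch
import torch.nn.functional as F

def iou_with_dilation(output, label, n_classes, dilate_classes=(1,), ksize=5):
    """Per-class IoU; predicted masks for classes in `dilate_classes`
    are dilated with a ksize x ksize square kernel before comparison."""
    SMOOTH = 1e-6
    pred = torch.argmax(output, dim=1)  # (B, H, W), stays on the GPU
    ious = []
    for idx in range(n_classes):
        pred_mask = (pred == idx).float()
        if idx in dilate_classes:
            # Max-pooling a binary mask with stride 1 and "same" padding
            # is exactly a dilation with a square structuring element.
            pred_mask = F.max_pool2d(
                pred_mask.unsqueeze(1), kernel_size=ksize,
                stride=1, padding=ksize // 2,
            ).squeeze(1)
        pred_mask = pred_mask.bool().view(-1)
        target_mask = (label == idx).view(-1)
        intersection = (pred_mask & target_mask).float().sum()
        union = (pred_mask | target_mask).float().sum()
        ious.append(((intersection + SMOOTH) / (union + SMOOTH)).item())
    return np.array(ious)
```

With `dilate_classes=()` this reduces to the plain IoU, so the dilated and undilated metrics can be computed from the same function.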