I am trying to do the following:
- Do one iteration of dilation on the output of the model (in validation or testing).
- Compute the IoU after the dilation for each class.
I need the dilation step because I am willing to tolerate some error for the second class. For example, if the model outputs a thin line and the ground truth is a bit thicker, the IoU normally suffers, but for my purposes a thin line is acceptable. For the first class (background) and the third class, however, I don't want to apply the dilation.
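One way to get this behaviour, sketched here under assumptions, is to dilate only the binary mask of the class you want to be lenient with (index 1) and score the other classes undilated. The sketch keeps everything on the GPU by using `torch.nn.functional.max_pool2d` with stride 1 as a binary dilation, instead of a round trip through NumPy and `cv2.dilate`. It assumes `output` has shape (N, C, H, W) and `label` holds integer class indices with N·H·W elements; the names `dilate_mask` and `iou_with_tolerance` and the `dilate_classes` parameter are hypothetical.

```python
import torch
import torch.nn.functional as F


def dilate_mask(mask: torch.Tensor, kernel_size: int = 5) -> torch.Tensor:
    """Binary dilation of a (N, H, W) boolean mask with a square kernel.

    max_pool2d with stride 1 takes the max over each kernel_size x kernel_size
    window, which for a 0/1 image is binary dilation. It runs on the mask's
    own device (CPU or GPU).
    """
    pad = kernel_size // 2  # keeps H and W unchanged for an odd kernel_size
    x = mask.float().unsqueeze(1)  # (N, 1, H, W); max_pool2d needs floats
    x = F.max_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad)
    return x.squeeze(1) > 0


def iou_with_tolerance(output, label, n_classes, dilate_classes=(1,),
                       kernel_size=5, smooth=1e-6):
    """Per-class IoU where only the predictions of `dilate_classes` are dilated (hypothetical helper)."""
    pred = torch.argmax(output, dim=1)  # (N, H, W) predicted class indices
    target = label.reshape(pred.shape)  # assumes label has N*H*W elements
    ious = []
    for idx in range(n_classes):
        pred_inds = pred == idx
        if idx in dilate_classes:
            # Tolerate thin predictions for this class only
            pred_inds = dilate_mask(pred_inds, kernel_size)
        target_inds = target == idx
        intersection = (pred_inds & target_inds).float().sum()
        union = (pred_inds | target_inds).float().sum()
        ious.append(((intersection + smooth) / (union + smooth)).item())
    return ious
```

Each class is scored independently, so the dilated class-1 mask may overlap pixels predicted as another class; this sketch accepts that because only the class-1 score is meant to be lenient. Keep in mind that dilation also adds false positives when the prediction is already as thick as the ground truth. Wrap the result in `np.array(...)` if you need the same return type as your current function.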
Currently I dilate the whole output, i.e. all classes at once, as shown below.
Without dilation, this code works great:
```python
import numpy as np
import torch


def iou(output, label, n_classes):
    ious = []
    # Predicted class index per pixel, flattened to 1-D
    pred = torch.argmax(output, dim=1).squeeze(1).view(-1)
    target = label.view(-1)
    SMOOTH = 1e-6
    for idx in range(n_classes):
        preds_inds = pred == idx
        target_inds = target == idx
        intersection = (preds_inds & target_inds).float().sum()
        union = (preds_inds | target_inds).float().sum()
        iou = (intersection + SMOOTH) / (union + SMOOTH)
        ious.append(iou.item())
    return np.array(ious)
```
Adding the dilation step:
```python
import cv2
import numpy as np
import torch


def iou(output, label, n_classes):
    ious = []
    get_cuda_device = "cuda:0"
    cuda_check = output.is_cuda
    if cuda_check:
        get_cuda_device = output.get_device()
    kernel = np.ones((5, 5), np.uint8)
    # Flatten the prediction, move it to NumPy and dilate it with OpenCV
    pred = torch.argmax(output, dim=1).squeeze(1).view(-1).detach().cpu().numpy() * 255
    pred = cv2.dilate(pred.astype(np.uint8), kernel, iterations=1) / 255
    pred = torch.from_numpy(pred).long().to(get_cuda_device)
    target = label.view(-1)
    SMOOTH = 1e-6
    for idx in range(n_classes):
        preds_inds = pred == idx
        target_inds = target == idx
        intersection = (preds_inds & target_inds).float().sum()
        union = (preds_inds | target_inds).float().sum()
        iou = (intersection + SMOOTH) / (union + SMOOTH)
        ious.append(iou.item())
    return np.array(ious)
```
Now I get the following error:

```
intersection = (preds_inds & target_inds).float().sum()
RuntimeError: CUDA out of memory. Tried to allocate 7440.75 GiB (GPU 0; 10.76 GiB total capacity; 463.73 MiB already allocated; 9.33 GiB free; 584.00 MiB reserved in total by PyTorch)
```
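A likely explanation for the 7440 GiB request: `cv2.dilate` received the already-flattened prediction and presumably returned it as a column of shape (N, 1) (the `.squeeze(1)` in the fix works, which is consistent with this), while `target` has shape (N,). Combining the two broadcasts to an (N, N) tensor, so memory grows with the square of the pixel count. PyTorch follows NumPy's broadcasting rules, so the effect can be reproduced with a small NumPy array; `n` is an arbitrary stand-in for the pixel count.

```python
import numpy as np

n = 6  # arbitrary stand-in for the number of pixels
pred = np.zeros((n, 1), dtype=np.int64)  # column shape assumed to come back from cv2.dilate
target = np.zeros(n, dtype=np.int64)     # label.view(-1) has shape (n,)

combined = (pred == 0) & (target == 0)   # (n, 1) with (n,) broadcasts to (n, n)
print(combined.shape)  # (6, 6)

# Flattening both sides to the same 1-D shape avoids the blow-up
fixed = (pred.reshape(-1) == 0) & (target == 0)
print(fixed.shape)  # (6,)
```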
Can someone suggest a fix, or even a better way to achieve what I described above (not necessarily using dilation)?
EDIT: Fixed the problem by replacing

```python
pred = torch.from_numpy(pred).long().to(get_cuda_device)
```

with

```python
pred = torch.from_numpy(pred).long().squeeze(1).view(-1).to(get_cuda_device)
```
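Even with this fix, two parts of the dilation step may not do what is intended. First, `cv2.dilate` still runs on the flattened vector, so a pixel's neighbours are presumably its neighbours in raster order rather than in the 2-D image. Second, multiplying class indices by 255 before casting to `uint8` overflows for any index above 1, and grayscale dilation is a local maximum, so it grows whichever class has the larger value. The NumPy snippet below shows the overflow for the classes 0, 1 and 2:

```python
import numpy as np

labels = np.array([0, 1, 2])              # background, class 1, class 2
scaled = (labels * 255).astype(np.uint8)  # 2 * 255 = 510 wraps modulo 256
print(scaled.tolist())  # [0, 255, 254]

restored = (scaled / 255).astype(np.int64)  # truncates, like .long() in the question
print(restored.tolist())  # [0, 1, 0]
```

Class 2 comes back as background. Dilating only the binary mask of the class of interest, in 2-D, avoids both issues.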
My question still holds, though: is there a better way to do this?