Computing loss over ROI

I am trying to implement a loss that would be expressed as L2(full_image) + weight * L2(ROI).
I found a previous topic suggesting I should create a mask over the ROI and zero out the loss outside of it.
So I did the following:

criterion = torch.nn.MSELoss(reduction="none")  # keep per-pixel losses so they can be masked
...
for batch_idx, (X, y) in enumerate(train_loader):
    optimizer.zero_grad()
    X = X.to(device)
    y = y.to(device)
    out = model(X)
    loss = criterion(out, y)                      # per-pixel L2 loss
    roi_mask = torch.where(y == 0, 0.0, 1.0)      # 1.0 inside the ROI, 0.0 elsewhere
    roi_loss = loss * roi_mask                    # zero out the loss outside the ROI
    mean_roi_loss = torch.sum(roi_loss) / torch.sum(roi_mask)
    final_loss = torch.mean(loss) + weight * mean_roi_loss
    final_loss.backward()
    optimizer.step()

However, I am facing 2 issues:

  • In some cases, my images do not have any ROI, so the mask is empty and mean_roi_loss becomes NaN (because torch.sum(roi_mask) == 0). I thought about wrapping all this in an if statement that checks whether the mask is empty and, in that case, sets final_loss = torch.mean(loss). But I am not sure whether backward computation would still work with this condition, since it creates a different loss depending on the images. I could also replace the denominator, e.g. mean_roi_loss = torch.sum(roi_loss) / (torch.sum(roi_mask) + 1) (see the empty-mask sketch after this list).
  • Should I compute mean(loss) + weight * mean_over_ROI(loss * mask), or mean(loss + weight * mask * loss), i.e. take the mean over the whole image rather than just over the ROI (see the two-options sketch after this list)?
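
For illustration, this is roughly what I had in mind for the empty-mask case (just a sketch, not tested; the clamp variant is an alternative to the +1 denominator mentioned above):

mask_sum = torch.sum(roi_mask)
if mask_sum > 0:
    # Variant 1: only add the ROI term when the mask is non-empty.
    # This is the conditional I am unsure about with respect to backward().
    final_loss = torch.mean(loss) + weight * torch.sum(roi_loss) / mask_sum
else:
    final_loss = torch.mean(loss)

# Variant 2: clamp the denominator so an empty mask gives 0 / 1 = 0 instead of NaN;
# for a non-empty mask the result is unchanged.
mean_roi_loss = torch.sum(roi_loss) / torch.clamp(torch.sum(roi_mask), min=1.0)
final_loss = torch.mean(loss) + weight * mean_roi_loss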
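
And to make the second question concrete, the two options I am comparing would be (sketch; loss_opt1 / loss_opt2 are just names I made up):

# Option 1: mean over the whole image, plus a weighted mean taken over the ROI pixels only
loss_opt1 = torch.mean(loss) + weight * torch.sum(loss * roi_mask) / torch.sum(roi_mask)

# Option 2: weight the ROI pixels, then average over the whole image,
# so the ROI term ends up divided by the total pixel count rather than by the ROI size
loss_opt2 = torch.mean(loss + weight * roi_mask * loss)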