I am trying to implement a loss that would be expressed as `L2(full_image) + weight * L2(ROI)`.
I found the following topic suggesting I should create a mask over the ROI and zero out the loss.
So I did the following:
```python
criterion = torch.nn.MSELoss(reduction="none")
...
for batch_idx, (X, y) in enumerate(train_loader):
    optimizer.zero_grad()
    X = X.to(device)
    y = y.to(device)
    out = model(X)
    loss = criterion(out, y)  # per-pixel loss, same shape as y
    roi_mask = torch.where(y == 0, 0, 1)  # ROI = nonzero target pixels
    roi_loss = loss * roi_mask
    mean_roi_loss = torch.sum(roi_loss) / torch.sum(roi_mask)
    final_loss = torch.mean(loss) + weight * mean_roi_loss
    final_loss.backward()
    optimizer.step()
```
However, I am facing two issues:
- In some cases, my images do not have any ROI, so the mask is empty and `mean_roi_loss` becomes NaN (because `torch.sum(roi_mask) == 0`). I thought of wrapping all this in an `if` statement that checks whether the mask is empty and, in that case, sets `final_loss = torch.mean(loss)`. But I am not sure whether backward computation would work with this condition, since it creates different loss expressions depending on the images. I could also replace the division with `mean_roi_loss = torch.sum(roi_loss) / (torch.sum(roi_mask) + 1)`.
- Should I compute `mean(loss) + weight * mean_over_ROI(loss * mask)` or `mean(loss + weight * mask * loss)`, i.e., take the second mean over the whole image rather than just the ROI?
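For reference, here is a minimal sketch of the NaN-safe variant I am considering, which clamps the denominator instead of branching (the `combined_loss` helper is my own name, and it assumes the same `y != 0` ROI convention as the snippet above; when the mask is empty the numerator is already 0, so clamping makes the ROI term 0 rather than NaN):

```python
import torch

def combined_loss(out, y, weight):
    """L2 over the full image plus `weight` * L2 over the ROI (nonzero targets)."""
    # Per-pixel squared error, no reduction
    loss = torch.nn.functional.mse_loss(out, y, reduction="none")
    # ROI mask: 1.0 on nonzero target pixels, 0.0 elsewhere
    roi_mask = (y != 0).float()
    roi_loss = loss * roi_mask
    # Clamp the denominator to at least 1: with an empty ROI the
    # numerator is 0, so the ROI term becomes 0 instead of NaN.
    mean_roi_loss = roi_loss.sum() / roi_mask.sum().clamp(min=1)
    return loss.mean() + weight * mean_roi_loss

# Empty-ROI example: target is all zeros, so only the full-image term remains
out = torch.full((1, 1, 4, 4), 0.5)
y = torch.zeros_like(out)
print(combined_loss(out, y, weight=2.0))  # tensor(0.2500)
```

Since `clamp` only rewrites the denominator, the gradient path through `roi_loss` is unchanged when the ROI is non-empty, and autograd never sees a 0/0.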