Hey!
I am working on an evaluation script for my semantic segmentation model and was looking for some IoU implementations.
The first one I found was this one:
import torch

EPS = 1e-6

# slightly modified
def get_IoU(outputs, labels):
    outputs = outputs.int()
    labels = labels.int()
    # Taken from: https://www.kaggle.com/iezepov/fast-iou-scoring-metric-in-pytorch-and-numpy
    intersection = (outputs & labels).float().sum((1, 2))  # Will be zero if Truth=0 or Prediction=0
    union = (outputs | labels).float().sum((1, 2))  # Will be zero if both are 0
    iou = (intersection + EPS) / (union + EPS)  # We smooth our division to avoid 0/0
    # thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10  # This is equal to comparing with thresholds
    # return thresholded.mean()  # Or thresholded.mean() if you are interested in average across the batch
    return iou.mean()
The second one was found here (Jaccard Index):
# computes the confusion matrix
def _fast_hist(true, pred, num_classes):
    mask = (true >= 0) & (true < num_classes)
    hist = torch.bincount(
        num_classes * true[mask] + pred[mask],
        minlength=num_classes ** 2,
    ).reshape(num_classes, num_classes).float()
    return hist

# computes IoU based on the confusion matrix
def jaccard_index(hist):
    """Computes the Jaccard index, a.k.a. the Intersection over Union (IoU).

    Args:
        hist: confusion matrix.

    Returns:
        avg_jacc: the average per-class Jaccard index.
        jaccard: the per-class Jaccard index.
    """
    A_inter_B = torch.diag(hist)
    A = hist.sum(dim=1)
    B = hist.sum(dim=0)
    jaccard = A_inter_B / (A + B - A_inter_B + EPS)
    avg_jacc = torch.nanmean(jaccard)  # the mean of jaccard without NaNs
    return avg_jacc, jaccard
I wanted to cross-check that both compute the same thing:
true = torch.tensor([[
    [1, 0, 0],
    [1, 0, 0],
    [1, 0, 0],
]])
pred = torch.tensor([[
    [1, 0, 0],
    [1, 1, 0],
    [1, 0, 0],
]])
print(true.shape, pred.shape)
print(get_IoU(pred, true))
# Arguments follow the (true, pred) order of the signature; the IoU itself is symmetric in the two.
hist = _fast_hist(true, pred, num_classes=2)
print(jaccard_index(hist))
The result is:
torch.Size([1, 3, 3]) torch.Size([1, 3, 3])
tensor(0.7500)
(tensor(0.7917), tensor([0.8333, 0.7500]))
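To see where each of these numbers comes from, here is a minimal sketch that counts intersection and union per class directly on the example above. The helper name `per_class_iou` is my own and not part of either implementation; it only reproduces the arithmetic for this one 3x3 example.

```python
import torch

true = torch.tensor([[[1, 0, 0], [1, 0, 0], [1, 0, 0]]])
pred = torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 0, 0]]])

def per_class_iou(true, pred, cls):
    """IoU of a single class, counted directly from the label maps (hypothetical helper)."""
    t = true == cls
    p = pred == cls
    intersection = (t & p).sum().item()  # pixels both maps assign to cls
    union = (t | p).sum().item()         # pixels either map assigns to cls
    return intersection / union

iou_bg = per_class_iou(true, pred, 0)  # background: 5 shared pixels out of 6
iou_fg = per_class_iou(true, pred, 1)  # foreground: 3 shared pixels out of 4
mean_iou = (iou_bg + iou_fg) / 2       # average over the two classes

print(iou_bg, iou_fg, mean_iou)
```

So, at least on this example, 0.75 looks like the IoU of class 1 only, 0.8333 the IoU of class 0, and 0.7917 the mean of the two.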
I have two questions:
- What is the correct IoU result? 0.75 / 0.83 / 0.79?
- Can someone explain to me what the purpose of the EPS (I assume epsilon?) value is and why we need it?
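For the second question, the only hint I have is the comment in the first snippet about avoiding 0/0. Below is a small sketch of that case, assuming a class that appears in neither the prediction nor the label, so that both intersection and union are zero. The tensors are made up for illustration.

```python
import torch

EPS = 1e-6

# Hypothetical edge case: the class is absent from both prediction and label.
intersection = torch.tensor([0.0])
union = torch.tensor([0.0])

iou_plain = intersection / union                   # 0/0 evaluates to NaN
iou_smooth = (intersection + EPS) / (union + EPS)  # EPS/EPS evaluates to 1

print(iou_plain, iou_smooth)
```

If that reading is right, EPS only matters when the union is empty, and it is small enough to leave the other cases practically unchanged.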
Thanks for any help!