I have implemented yolov5m from pseudo-scratch and I am having troubles to debug the loss function. In particular after 50 epochs on coco128 (first 128 images of MS COCO) my net is having a MAP of 0, objectness accuracy of 0% and no-obj accuracy near 100 %.
I checked the code before the loss fn and targets are built correctly (checked by plotting images and the boxes) and therefore I concluded that the issue is in the loss fx.
I created my loss function it by merging some parts of yolov5 ultralytics loss fn and other parts by Aladdin Persson.
The result is the following:
class YOLO_LOSS:
def __init__(self, model):
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
self.entropy = nn.CrossEntropyLoss()
self.sigmoid = nn.Sigmoid()
self.lambda_class = 1
self.lambda_noobj = 1
self.lambda_obj = 1
self.lambda_box = 10
anchors = getattr(model.head, "anchors").clone().detach()
self.anchors = torch.cat((anchors[0], anchors[1], anchors[2]), dim=0)
self.na = self.anchors.shape[0]
self.num_anchors_per_scale = self.na // 3
self.S = getattr(model.head, "stride")
self.ignore_iou_thresh = 0.5
self.ph = None # this variable is used in the build_targets method, defined here for readability.
self.pw = None # this variable is used in the build_targets method, defined here for readability.
def __call__(self, preds, targets, pred_size):
# list of lists --> [pred[0].height, pred[0].width, pred[1].height... etc]
targets = [self.build_targets(preds, bboxes, pred_size) for bboxes in targets]
t1 = torch.stack([target[0] for target in targets], dim=0).to(config.DEVICE)
t2 = torch.stack([target[1] for target in targets], dim=0).to(config.DEVICE)
t3 = torch.stack([target[2] for target in targets], dim=0).to(config.DEVICE)
anchors = self.anchors.reshape(3, 3, 2).to(config.DEVICE)
loss = (self.compute_loss(preds[0], t1, anchors=anchors[0])
+ self.compute_loss(preds[1], t2, anchors=anchors[1])
+ self.compute_loss(preds[2], t3, anchors=anchors[2]))
return loss
def build_targets(self, input_tensor, bboxes, pred_size, check_loss=True):
ph = pred_size[0]
pw = pred_size[1]
targets = [
torch.zeros((self.num_anchors_per_scale, input_tensor[i].shape[2], input_tensor[i].shape[3], 6))
for i in range(len(self.S))
]
classes = [box[-1] - 1 for box in bboxes] # classes in coco start from 1
bboxes = [box[:-1] for box in bboxes]
bboxes = rescale_bboxes(bboxes, starting_size=(640, 640), ending_size=(pw, ph))
for idx, box in enumerate(bboxes):
class_label = classes[idx]
box = coco_to_yolo(box, pw, ph)
iou_anchors = iou_width_height(torch.tensor(box[2:4]),
self.anchors / torch.tensor([640, 640]).to(config.DEVICE))
anchor_indices = iou_anchors.argsort(descending=True, dim=0)
x, y, width, height, = box
has_anchor = [False] * 3
for anchor_idx in anchor_indices:
# i.e if the best anchor idx is 8, num_anchors_per_scale
# we know that 8//3 = 2 --> the best scale_idx is 2 -->
# best_anchor belongs to last scale (52,52)
# scale_idx will be used to slice the variable "targets"
# another pov: scale_idx searches the best scale of anchors
scale_idx = torch.div(anchor_idx, self.num_anchors_per_scale, rounding_mode="floor")
# print(scale_idx)
# anchor_on_scale searches the idx of the best anchor in a given scale
# found via index in the line below
anchor_on_scale = anchor_idx % self.num_anchors_per_scale
# slice anchors based on the idx of the best scales of anchors
scale_x = input_tensor[int(scale_idx)].shape[2]
scale_y = input_tensor[int(scale_idx)].shape[3]
i, j = int(scale_y * y), int(scale_x * x) # which cell
anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
if not anchor_taken and not has_anchor[scale_idx]:
targets[scale_idx][anchor_on_scale, i, j, 0] = 1
x_cell, y_cell = scale_x * x - j, scale_y * y - i # both between [0,1]
width_cell, height_cell = (
width * scale_x,
height * scale_y,
) # can be greater than 1 since it's relative to cell
box_coordinates = torch.tensor(
[x_cell, y_cell, width_cell, height_cell]
)
targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
has_anchor[scale_idx] = True
elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
targets[scale_idx][anchor_on_scale, i, j, 0] = -1 # ignore prediction
return targets
# TRAINING_LOSS
def compute_loss(self, preds, targets, anchors):
# originally anchors have shape (3,2) --> 3 set of anchors of width and height
anchors = anchors.reshape(1, 3, 1, 1, 2)
# because of https://github.com/ultralytics/yolov5/issues/471
xy = preds[..., 1:3].sigmoid() * 2 - 0.5
wh = (preds[..., 3:5].sigmoid() * 2) ** 2 * anchors
# Check where obj and noobj (we ignore if target == -1)
obj = targets[..., 0] == 1 # in paper this is Iobj_i
noobj = targets[..., 0] == 0 # in paper this is Inoobj_i
# ======================= #
# FOR NO OBJECT LOSS #
# ======================= #
# not doing sigmoid because self.bce is bce_with_logits
no_object_loss = self.bce(
# [..., 0:1] instead of [..., 0] to keep shape untouched
(preds[..., 0:1][noobj]), (targets[..., 0:1][noobj]),
)
# ==================== #
# FOR OBJECT LOSS #
# ==================== #
# https://www.youtube.com/watch?v=Grir6TZbc1M&t=5668s
# explanation of anchors at 12:17:
# dim=-1 means the last dim
box_preds = torch.cat([xy, wh], dim=-1)
ious = intersection_over_union(box_preds, targets[..., 1:5], GIoU=True).detach()
# mse or bce??
object_loss = self.bce(self.sigmoid(preds[..., 0:1][obj]), ious[obj] * targets[..., 0:1][obj])
# ======================== #
# FOR BOX COORDINATES #
# ======================== #
# preds[..., 1:3] = self.sigmoid(preds[..., 1:3]) # x,y coordinates
# targets[..., 3:5] = ((targets[..., 3:5]/2)**(1/2))/anchors
# width, height coordinates
# box_loss = self.mse(preds[..., 1:5][obj], targets[..., 1:5][obj])
box_loss = (1 - ious[obj]).mean()
# ================== #
# FOR CLASS LOSS #
# ================== #
class_loss = self.entropy(
(preds[..., 5:][obj]), (targets[..., 5][obj].long()),
)
return (
self.lambda_box * box_loss
+ self.lambda_obj * object_loss
+ self.lambda_noobj * no_object_loss
+ self.lambda_class * class_loss
)
I know it’s quite a vague topic but I am having quite a hard time figuring out what’s going wrong.