My loss function becomes 0 in the 2nd epoch

I am working on YOLOv8, converting the nuScenes dataset to a bird's-eye view and then passing it through YOLO. I am using both camera and radar. I am sharing my head, my loss function, the main file I run, and the printed output of the GT classes and boxes. I don't know what's going wrong:
HEAD:
# imports required by this snippet (assuming the stock ultralytics modules)
import copy
import math
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.nn.modules import Conv, DWConv, DFL
from ultralytics.utils.tal import make_anchors


class Detect(nn.Module):
    # NOTE: decode_bboxes and postprocess (used below) are the stock ultralytics
    # Detect methods and are not shown in this snippet
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format
    end2end = False  # end2end
    max_det = 300  # max_det
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init
    legacy = False  # backward compatibility for v3/v5/v8/v9 models
    xyxy = False  # xyxy or xywh output

    def __init__(self, nc: int = 80, ch: Tuple = ()):
        """
        Initialize the YOLO detection layer with specified number of classes and channels.

        Args:
            nc (int): Number of classes.
            ch (tuple): Tuple of channel sizes from backbone feature maps.
        """
        super().__init__()
        self.tasks = [  # define the tasks
            dict(num_class=1, class_names=['car']),
            dict(num_class=2, class_names=['truck', 'construction_vehicle']),
            dict(num_class=2, class_names=['bus', 'trailer']),
            dict(num_class=1, class_names=['barrier']),
            dict(num_class=2, class_names=['motorcycle', 'bicycle']),
            dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
        ]

        self.num_classes = sum(t['num_class'] for t in self.tasks)
        self.nc = self.num_classes
        self.num_heads = len(self.tasks)
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.num_reg_targets = 9  # x, y, z, w, l, h, yaw, vx, vy
        # use self.nc (derived from the tasks), not the nc argument, so the channel
        # split in _inference matches the actual head output (10 + 16 * 9 = 154)
        self.no = self.nc + self.reg_max * self.num_reg_targets  # updated number of outputs
        #self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 9)), max(ch[0], min(self.nc, 100))  # channels
        #self.cv2 = nn.ModuleList(
        #   nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        #)
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, self.num_reg_targets * self.reg_max, 1))
            for x in ch
        )
        # one classification head per (task, level) pair, indexed as task_idx * self.nl + level_idx
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                DWConv(x, x, 3), Conv(x, c3, 1),
                DWConv(c3, c3, 3), Conv(c3, c3, 1),
                nn.Conv2d(c3, task['num_class'], 1)
            )
            for task in self.tasks
            for x in ch
        ])
        '''
        self.cv3 = (
            nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
            if self.legacy
            else nn.ModuleList(
                nn.Sequential(
                    nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
                    nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
                    nn.Conv2d(c3, self.nc, 1),
                )
                for x in ch
            )
        )
        '''
        # NOTE: the stock ultralytics DFL module hardcodes 4 box channels internally
        # (x.view(b, 4, c1, a)), so it will not decode 9 regression targets as-is
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            # cv3 was replaced by task_heads above, so copy those instead of the removed cv3
            self.one2one_task_heads = copy.deepcopy(self.task_heads)

    def forward(self, x: List[torch.Tensor]) -> Union[List[torch.Tensor], Tuple]:
        """Concatenate and return predicted bounding boxes and class probabilities."""
        if self.end2end:
            return self.forward_end2end(x)
        for i in range(self.nl):
            reg = self.cv2[i](x[i])  # regression
            cls_list = []
            for j in range(self.num_heads):
                cls_out = self.task_heads[j * self.nl + i](x[i])
                cls_list.append(cls_out)
            cls = torch.cat(cls_list, dim=1)
            x[i] = torch.cat((reg, cls), dim=1)

        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x: List[torch.Tensor]) -> Union[dict, Tuple]:
        x_detach = [xi.detach() for xi in x]

        # One2One path (detached features)
        one2one = []
        for i in range(self.nl):
            reg = self.one2one_cv2[i](x_detach[i])
            cls_list = []
            for j in range(self.num_heads):
                cls_out = self.task_heads[j * self.nl + i](x_detach[i])
                cls_list.append(cls_out)
            cls = torch.cat(cls_list, dim=1)
            one2one.append(torch.cat((reg, cls), dim=1))

        # One2Many path (normal forward)
        for i in range(self.nl):
            reg = self.cv2[i](x[i])
            cls_list = []
            for j in range(self.num_heads):
                cls_out = self.task_heads[j * self.nl + i](x[i])
                cls_list.append(cls_out)
            cls = torch.cat(cls_list, dim=1)
            x[i] = torch.cat((reg, cls), dim=1)

        if self.training:
            return {"one2many": x, "one2one": one2one}

        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x: List[torch.Tensor]) -> torch.Tensor:
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.format != "imx" and (self.dynamic or self.shape != shape):
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            #box = x_cat[:, : self.reg_max * 4]
            #cls = x_cat[:, self.reg_max * 4 :]
            box = x_cat[:, : self.reg_max * self.num_reg_targets]
            cls = x_cat[:, self.reg_max * self.num_reg_targets :]
        else:
            #box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
            box, cls = x_cat.split((self.reg_max * self.num_reg_targets, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            # NOTE: this export path still assumes 4 box channels (xywh) and would
            # need updating before it can work with 9 regression targets
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        elif self.export and self.format == "imx":
            dbox = self.decode_bboxes(
                self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False
            )
            return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
        cls = (F.softmax(cls, dim=1) if not self.training else cls).permute(0, 2, 1)
        conf, class_id = cls.max(dim=2, keepdim=True)  # both (B, N, 1)
        dbox = dbox.permute(0, 2, 1)
        return torch.cat((dbox, conf, class_id.float()), dim=2)

    def bias_init(self):
        m = self  # Detect() module

        # Initialize regression heads (cv2)
        for a, s in zip(m.cv2, m.stride):
            a[-1].bias.data[:] = 1.0  # box regression biases

        # Initialize classification heads (task_heads); heads are ordered task-major,
        # i.e. index = task_idx * nl + level_idx, so recover both indices from idx
        for idx, b in enumerate(m.task_heads):
            task = self.tasks[idx // self.nl]
            level = idx % self.nl  # which detection level this head belongs to
            stride = m.stride[level]
            num_cls = task['num_class']
            b[-1].bias.data[:num_cls] = math.log(5 / num_cls / (640 / stride) ** 2)

        # Optional: handle end2end if needed
        if self.end2end:
            for a, s in zip(m.one2one_cv2, m.stride):
                a[-1].bias.data[:] = 1.0
            # You'd also need to initialize one2one_task_heads if end2end is active and used.
LOSS:
class FocalLoss(nn.Module):
    """
    Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).

    Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
    on hard negatives during training.

    Attributes:
        gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
        alpha (torch.Tensor): The balancing factor used to address class imbalance.
    """

    def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
        """Initialize FocalLoss class with focusing and balancing parameters."""
        super().__init__()
        self.gamma = gamma
        self.alpha = torch.tensor(alpha)

    def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Calculate focal loss with modulating factors for class imbalance."""
        # NOTE: this is a binary focal loss built on BCEWithLogits, so `label` must be a
        # float tensor of the same shape as `pred` (e.g. one-hot), not integer class indices
        loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = pred.sigmoid()  # prob from logits
        p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= modulating_factor
        if (self.alpha > 0).any():
            self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
            alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
            loss *= alpha_factor
        return loss.mean(1).sum()


import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict


class v8DetectionLoss:
    def __init__(self, code_weights=None, num_classes=10):
        # one weight per regression target (the boxes here have 9 values)
        self.code_weights = torch.tensor(code_weights) if code_weights is not None else torch.ones(9)
        self.num_classes = num_classes

        self.cls_loss_fn = FocalLoss(alpha=0.25, gamma=2.0)
        #self.cls_loss_fn = nn.CrossEntropyLoss(reduction='mean')
        # reduction='none' so code_weights can be applied per regression target
        self.reg_loss_fn = nn.SmoothL1Loss(reduction='none')

        # Initialize metrics storage
        self.reset_metrics()

    def reset_metrics(self):
        """Reset all metrics at the start of each epoch"""
        self.metrics = {
            'total_reg_loss': 0.0,
            'total_cls_loss': 0.0,
            'per_class': {
                cls_id: {
                    'reg_loss': 0.0,
                    'cls_loss': 0.0,
                    'tp': 0,    # true positives
                    'fp': 0,    # false positives
                    'fn': 0,    # false negatives
                    'count': 0  # total ground truths
                } for cls_id in range(self.num_classes)
            }
        }

    def update_metrics(self, cls_id, reg_loss, cls_loss, is_tp=False, is_fp=False, is_fn=False):
        """Update per-class metrics"""
        cls_metrics = self.metrics['per_class'][cls_id]

        reg_loss_val = reg_loss.item() if isinstance(reg_loss, torch.Tensor) else reg_loss
        cls_loss_val = cls_loss.item() if isinstance(cls_loss, torch.Tensor) else cls_loss
        cls_metrics['reg_loss'] += reg_loss_val
        cls_metrics['cls_loss'] += cls_loss_val
        cls_metrics['count'] += 1
        if is_tp:
            cls_metrics['tp'] += 1
        if is_fp:
            cls_metrics['fp'] += 1
        if is_fn:
            cls_metrics['fn'] += 1

    def get_class_metrics(self, cls_id):
        """Get computed metrics for a specific class"""
        m = self.metrics['per_class'][cls_id]
        precision = m['tp'] / (m['tp'] + m['fp'] + 1e-16)
        recall = m['tp'] / (m['tp'] + m['fn'] + 1e-16)
        return {
            'reg_loss': m['reg_loss'] / max(1, m['count']),
            'cls_loss': m['cls_loss'] / max(1, m['count']),
            'precision': precision,
            'recall': recall,
            'count': m['count']
        }

    def loss(self, preds_dicts, targets, **kwargs):
        self.reset_metrics()  # Reset metrics at start of batch

        if isinstance(preds_dicts, list):
            reshaped = []
            for p in preds_dicts:
                B, C, H, W = p.shape
                p = p.permute(0, 2, 3, 1).reshape(B, -1, C)
                reshaped.append(p)
            preds_dicts = torch.cat(reshaped, dim=1)

        B, N, _ = preds_dicts.shape
        pred_boxes = preds_dicts[:, :, :9]
        #pred_scores = preds_dicts[:, :, 9]
        #pred_classes = preds_dicts[:, :, 10].long()  # Convert to long for comparison
        pred_class_logits = preds_dicts[:, :, 10:]  # (B, N, num_classes)

        # Now get predicted classes as:
        pred_classes = torch.argmax(pred_class_logits, dim=-1)  # (B, N)

        gt_boxes = targets['gt_boxes']
        gt_classes = targets['gt_classes'].long()  # Ensure same type as pred_classes
        mask = targets.get('mask', torch.ones_like(gt_classes))  # Default all valid

        device = gt_boxes.device
        total_reg_loss = torch.tensor(0.0, device=device, requires_grad=True)
        total_cls_loss = torch.tensor(0.0, device=device, requires_grad=True)

        # Track which predictions have been matched to avoid duplicate counting
        matched_preds = torch.zeros(B, N, dtype=torch.bool, device=device)

        for b in range(B):
            valid_gt = mask[b].bool()
            gt_classes_b = gt_classes[b][valid_gt]
            gt_boxes_b = gt_boxes[b][valid_gt]

            if gt_boxes_b.shape[0] == 0:
                continue  # Skip images with no ground truth

            if b == 0:  # only print for first batch to reduce clutter
                print("\n--- Batch 0 Debug ---")
                print("Predicted raw logits (first 5):", pred_class_logits[b][:5].detach().cpu())
                pred_probs = torch.softmax(pred_class_logits[b], dim=-1)
                print("Predicted softmax probs (first 5):", pred_probs[:5].detach().cpu())
                print("Predicted classes (first 5):", pred_classes[b][:5].detach().cpu())
                print("GT classes:", gt_classes_b.detach().cpu())

            for k in range(gt_boxes_b.shape[0]):
                gt_box = gt_boxes_b[k]
                gt_class = gt_classes_b[k]

                # Find predictions of same class that haven't been matched yet
                class_match = (pred_classes[b] == gt_class) & (~matched_preds[b])
                match_idx = class_match.nonzero(as_tuple=False).squeeze(-1)

                if match_idx.numel() > 0:
                    # Select prediction with highest score for this class
                    # best_idx = match_idx[torch.argmax(pred_scores[b][match_idx])]
                    match_logits = pred_class_logits[b][match_idx]  # (M, num_classes)
                    match_probs = torch.softmax(match_logits, dim=-1)  # (M, num_classes)
                    match_conf, _ = torch.max(match_probs, dim=-1)     # (M,)
                    best_local_idx = torch.argmax(match_conf)          # Scalar
                    best_idx = match_idx[best_local_idx]
                    if b == 0:
                        print(f"\n[GT {k}] Matched class {gt_class.item()} to pred idx {best_idx.item()}")
                        print("GT Box:", gt_box.detach().cpu().numpy())
                        print("Pred Box:", pred_boxes[b, best_idx].detach().cpu().numpy())

                    matched_preds[b, best_idx] = True  # Mark as matched
                    is_tp = True
                    is_fp = False
                else:
                    # No prediction for this class - count as false negative
                    self.update_metrics(gt_class.item(), 0, 0, is_fn=True)
                    continue

                pred_box = pred_boxes[b, best_idx]
                dist_thresh = 200  # in meters
                if torch.norm(pred_box[:3] - gt_box[:3]) > dist_thresh:
                    # Skip bad match: count FN for GT, optionally FP for this pred
                    self.update_metrics(gt_class.item(), 0, 0, is_fn=True)
                    continue
                #pred_score = pred_scores[b, best_idx]
                reg_loss = self.reg_loss_fn(pred_box, gt_box)  # (9,) with reduction='none'
                reg_loss = (reg_loss * self.code_weights.to(device)).mean()
                reg_loss *= 0.25
                # Compute classification loss for all predictions of this GT class
                # NOTE: best_idx was just marked as matched, so it is excluded by this
                # mask -- the matched prediction itself never receives a cls loss here
                class_mask = (pred_classes[b] == gt_class) & (~matched_preds[b])
                if class_mask.any():
                    cls_loss = self.cls_loss_fn(
                        pred_class_logits[b, class_mask],
                        torch.full((class_mask.sum(),), gt_class, device=device)
                    )
                    cls_loss = cls_loss.mean()
                else:
                    # fallback in case class_mask is empty
                    cls_loss = self.cls_loss_fn(
                        pred_class_logits[b, best_idx].unsqueeze(0),
                        gt_class.unsqueeze(0)
                    )

                # Update metrics
                self.update_metrics(gt_class.item(), reg_loss, cls_loss, is_tp=is_tp)

                total_reg_loss = total_reg_loss + reg_loss
                total_cls_loss = total_cls_loss + cls_loss

            # Count false positives (predictions not matched to any GT)
            #unmatched_preds = (~matched_preds[b]) & (pred_scores[b] > 0.5)  # Using 0.5 threshold
            pred_probs = torch.softmax(pred_class_logits[b], dim=-1)
            pred_confidence, _ = torch.max(pred_probs, dim=-1)
            unmatched_preds = (~matched_preds[b]) & (pred_confidence > 0.5)
            valid_unmatched = unmatched_preds & (pred_classes[b] >= 0) & (pred_classes[b] < self.num_classes)
            # Penalty for unmatched high-confidence predictions (false positives)
            fp_mask = valid_unmatched
            if fp_mask.any():
                # NOTE: class 0 is 'car', not a dedicated background class, so this
                # pushes false positives toward predicting 'car'
                cls_fp_loss = self.cls_loss_fn(
                    pred_class_logits[b, fp_mask],
                    torch.zeros_like(pred_classes[b, fp_mask], device=device)  # dummy background class
                )
                total_cls_loss = total_cls_loss + 0.5 * cls_fp_loss  # scale down FP loss

            # Still update metrics
            for cls_id in pred_classes[b][valid_unmatched].unique():
                self.update_metrics(cls_id.item(), torch.tensor(0.0, device=device), torch.tensor(0.0, device=device), is_fp=True)
            if b == 0 and valid_unmatched.sum() > 0:
                print("\nUnmatched high-confidence predictions:")
                print("Classes:", pred_classes[b][valid_unmatched].detach().cpu().tolist())
                print("Confidences:", pred_confidence[valid_unmatched].detach().cpu().tolist())
        total_loss = total_reg_loss + total_cls_loss

        # Update global metrics
        self.metrics['total_loss'] = total_loss.detach().item()
        self.metrics['total_reg_loss'] = total_reg_loss.detach().item()
        self.metrics['total_cls_loss'] = total_cls_loss.detach().item()

        return total_loss, total_reg_loss, total_cls_loss

    def print_metrics(self, class_names=None):
        """Print formatted metrics for all classes"""
        if class_names is None:
            class_names = [str(i) for i in range(self.num_classes)]

        print("\nPer-Class Metrics:")
        print(f"{'Class':<15} {'Count':<8} {'Reg Loss':<10} {'Cls Loss':<10} {'Precision':<10} {'Recall':<10}")

        for cls_id in range(self.num_classes):
            metrics = self.get_class_metrics(cls_id)
            if metrics['count'] > 0:  # Only show classes with samples
                print(f"{class_names[cls_id]:<15} {metrics['count']:<8} "
                      f"{metrics['reg_loss']:.4f}    {metrics['cls_loss']:.4f}    "
                      f"{metrics['precision']:.4f}    {metrics['recall']:.4f}")

        print(f"\nTotal Loss: {self.metrics['total_loss']:.4f} "
              f"(Reg: {self.metrics['total_reg_loss']:.4f}, "
              f"Cls: {self.metrics['total_cls_loss']:.4f})")

MAIN FILE:
backbone_pts = PtsBackbone(**backbone_pts_conf)
fuser = MFAFuser(**fuser_conf).cuda()
backbone_img = backbone_img.cuda()
backbone_pts = backbone_pts.cuda()
fuser = fuser.cuda()
backbone_img.eval()
backbone_pts.eval()
fuser.eval()

# Sample forward (inside your loop)

for batch in dataloader:
    mats_dict = {
        'sensor2ego_mats': batch[1].cuda(),
        'ego2global_mats': batch[2].cuda(),
        'ida_mats': batch[3].cuda(),
        'sensor2sensor_mats': batch[4].cuda(),
        'intrin_mats': batch[5].unsqueeze(1).unsqueeze(1).cuda(),
        'bda_mat': batch[5].cuda()  # NOTE: same index (batch[5]) as intrin_mats -- double-check this
    }
    img = batch[0].cuda()  # images: shape (B, C, H, W)
    ptss_sweep = batch[11].cuda()
    ptss_context, ptss_occupancy, _ = backbone_pts(ptss_sweep)
    feats, depth, _ = backbone_img(img, mats_dict, ptss_context, ptss_occupancy, return_depth=True)
    fused_feats, _ = fuser(feats)
    #print("Fused BEV OUTPUT")
    #print("Fused features shape:", fused_feats.shape)

from utils.loss import v8DetectionLoss  # adjust path as needed
from ultralytics.nn.tasks import DetectionModel
import torch
from torch import optim
from tqdm import tqdm

model = DetectionModel(cfg='/home/ivision/YOLO/ultralytics/ultralytics/cfg/models/v8/yolov8.yaml', ch=256).cuda()

class_names = ['car', 'truck', 'construction_vehicle', 'bus', 'trailer',
               'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone']
loss_fn = v8DetectionLoss(num_classes=len(class_names))
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-2)
num_epochs = 24  # change as needed
model.train()

for epoch in range(num_epochs):
    for i, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):

        optimizer.zero_grad()

        mats_dict = {
            'sensor2ego_mats': batch[1].cuda(),
            'ego2global_mats': batch[2].cuda(),
            'ida_mats': batch[3].cuda(),
            'sensor2sensor_mats': batch[4].cuda(),
            'intrin_mats': batch[5].unsqueeze(1).unsqueeze(1).cuda(),
            'bda_mat': batch[5].cuda()
        }
        imgs = batch[0].cuda()
        pts = batch[11].cuda()

        ptss_context, ptss_occupancy, _ = backbone_pts(pts)
        feats, _, _ = backbone_img(imgs, mats_dict, ptss_context, ptss_occupancy, return_depth=True)
        fused_feats, _ = fuser(feats)

        preds = model(fused_feats)
        gt_boxes = batch[8].cuda()           # (1, K, 9)
        gt_classes = batch[9].long().cuda()  # (1, K)

        batch_dict = {
            "gt_boxes": gt_boxes,
            "gt_classes": gt_classes
        }

        total_loss, reg_loss, cls_loss = loss_fn.loss(preds, batch_dict)
        total_loss.backward()
        optimizer.step()

        if i % 50 == 0:  # Print every 50 batches
            tqdm.write(f"\n[Epoch {epoch+1} | Batch {i+1}]")
            tqdm.write(f"Total Loss: {total_loss.item():.4f} | Reg: {reg_loss.item():.4f} | Cls: {cls_loss.item():.4f}")
            loss_fn.print_metrics(class_names)

    # Optional: Save checkpoint
    torch.save(model.state_dict(), f"/home/ivision/YOLO/ultralytics/ultralytics/checkpoints/yolov8_epoch{epoch}.pth")

My OUTPUT:
--- Batch 0 Debug ---
Predicted classes (first 5): tensor([9, 3, 9, 9, 4, 9, 9, 9, 9, 9, 9, 9, 9, 5, 4, 3, 4, 5, 5, 9, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 4, 4, 4, 5, 3, 5, 5, 3, 3, 5, 5, 5, 5, 3, 5, 5, 9, 4, 7, 5, 2, 9, 4, 4, 3, 7, 3, 3, 7, 7, 7, 3, 4, 9, 9, 5, 4, 9, 4, 9, 5, 9, 9, 9, 9, 9, 9, 4, 9, 7, 9, 5, 5, 5, 5, 2, 2, 7, 7, 7, 7, 7, 8, 8, 9, 4, 9, 5, 9, 9, 9, 6,
2, 5, 8, 8, 8, 8, 8, 8, 9, 4, 4, 9, 5, 9, 9, 2, 5, 5, 8, 8, 8, 8, 8, 8, 9, 4, 4, 9, 9, 4, 5, 2, 2, 5, 8, 8, 8, 8, 8, 8, 9, 4, 9, 5, 5, 9, 5, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 4, 7, 4, 5, 9, 5, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 4, 7, 4, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 6, 4, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
5, 4, 3, 4, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 4, 2, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 4, 4, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 8, 5, 8, 8, 7, 7, 5, 5, 3, 9, 5, 5, 5, 7, 9, 7, 5, 9, 8, 5, 5, 5, 7, 5, 8, 5, 5, 5, 9, 8, 8, 7, 5, 5, 5, 7, 9, 8, 9, 9, 8, 8, 8, 8, 7, 7, 8, 9, 8, 5, 8, 9, 9, 9, 9,
9, 9, 8, 9, 9, 9, 9, 9, 4, 5, 5, 7, 5, 5, 5, 3, 5, 5, 9, 3, 8, 5, 9, 6])
GT classes: tensor([8, 0, 0])

[GT 6] Matched class 7 to pred idx 133
GT Box: [ 23.047 -14.291 1.2105 1.7485 0.50254 1.1815 1.4128 -0.1014 0.19942]
Pred Box: [ 0.79653 -0.40139 1.5749 0.24956 0.9028 1.7428 0.23286 1.7748 0.55242]

[Epoch 1 | Batch 51]
Total Loss: 3.3326 | Reg: 1.5644 | Cls: 1.7682
Epoch 1/5: 1%|█▎ | 49/6019 [00:04<07:56, 12.53it/s]
Per-Class Metrics:
Class           Count    Reg Loss   Cls Loss   Precision  Recall
car             2        0.0000     0.0000     0.0000     0.0000
pedestrian      1        1.5644     1.7682     1.0000     1.0000

Total Loss: 3.3326 (Reg: 1.5644, Cls: 1.7682)
After the 2nd epoch:

[Epoch 2 | Batch 19201]
Total Loss: 11.1895 | Reg: 11.1895 | Cls: 0.0000
Epoch 2/24: 68%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 19200/28130 [2:07:42<54:13, 2.75it/s]
Per-Class Metrics:
Class           Count    Reg Loss   Cls Loss   Precision  Recall
car             11       1.0172     0.0000     0.9091     1.0000
truck           1        0.0000     0.0000     0.0000     0.0000
bus             1        0.0000     0.0000     0.0000     0.0000
bicycle         2        0.0000     0.0000     0.0000     0.0000
pedestrian      23       0.0000     0.0000     0.0000     0.0000

Total Loss: 11.1895 (Reg: 11.1895, Cls: 0.0000)

I think the main reason for the 0 cls loss is that v8DetectionLoss.loss() was originally designed to be called once per epoch: iterate over the batches, accumulate stats and losses, and only then reset its internal state (the metrics).

But in the current code snippet it's being called once per batch, which resets the internal metrics dict every batch. Since most predictions don't match any ground truth (especially in early training), cls_loss_fn is rarely called, leading to many batches where the cls loss is exactly zero.
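
If you want to keep that design, the minimal fix is to pull the reset out of loss() and do it once per epoch in the training loop. A rough sketch reusing the names from your main file (and assuming you delete the self.reset_metrics() call at the top of loss()):

for epoch in range(num_epochs):
    loss_fn.reset_metrics()  # reset once per epoch, not once per batch
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        ...  # forward pass through backbone_pts / backbone_img / fuser / model as above
        total_loss, reg_loss, cls_loss = loss_fn.loss(preds, batch_dict)
        total_loss.backward()
        optimizer.step()
    loss_fn.print_metrics(class_names)  # now summarizes the whole epoch

Note this mostly fixes the reporting: the per-batch losses themselves are computed the same way either way.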

Another possible logical bug is the slicing in pred_boxes = preds_dicts[:, :, :9]. I suggest printing preds_dicts.shape: I assume the last dimension isn't 19 as your current code assumes, but rather 154, which is 16*9 + 10, since you're using a distribution focal loss (DFL predicts logits for 16 bins per regression target, applies a softmax, then computes the weighted average as the predicted continuous value). That means a decoding step is needed there.
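
A rough sketch of what that decode could look like, assuming the layout matches your Detect head (regression logits first, target-major like the stock ultralytics DFL); decode_dfl and its arguments are illustrative names, not an existing API:

import torch

def decode_dfl(preds, reg_max=16, n_targets=9, num_classes=10):
    # preds: (B, N, reg_max * n_targets + num_classes) = (B, N, 154), regression logits first
    B, N, C = preds.shape
    reg, cls_logits = preds.split((reg_max * n_targets, num_classes), dim=-1)
    reg = reg.view(B, N, n_targets, reg_max).softmax(dim=-1)  # distribution over the 16 bins
    bins = torch.arange(reg_max, dtype=reg.dtype, device=reg.device)
    boxes = (reg * bins).sum(dim=-1)  # expected bin value per target -> (B, N, n_targets)
    return boxes, cls_logits

Also note the decoded values live in bin units [0, reg_max - 1]; you still have to map them to BEV/metric coordinates (anchors, strides, value ranges) before comparing them to your gt_boxes, which would explain why the matched pred box in your debug print looks nothing like the GT box.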
