loss_cls does not drop

I quote ghm loss in ssd, loss_bbox can normally drop to about 0.8, but loss_cls can only drop to about 3, what’s going on?code show as below:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import AnchorGenerator, anchor_target, multi_apply
from .anchor_head import AnchorHead
from …losses import smooth_l1_loss
from …registry import HEADS

def _expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)
return bin_labels, bin_label_weights

def GHM_Class_loss(pred, target, label_weight):
bins = 30
mmt = 0.75
loss_weight = 1.0
edges = torch.arange(bins + 1).float().cuda() / bins
edges[-1] += 1e-6
if mmt > 0:
acc_sum = torch.zeros(bins).cuda()
use_sigmoid = True
if not use_sigmoid:
raise NotImplementedError

if pred.dim() != target.dim():
    target, label_weight = _expand_binary_labels(
        target, label_weight, pred.size(-1))
target, label_weight = target.float(), label_weight.float()
weights = torch.zeros_like(pred)    

# gradient length
g = torch.abs(pred.sigmoid().detach() - target)

valid = label_weight > 0
tot = max(valid.float().sum().item(), 1.0)
n = 0  # n valid bins
for i in range(bins):
    inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
    num_in_bin = inds.sum().item()
    if num_in_bin > 0:
        if mmt > 0:
            acc_sum[i] = mmt * acc_sum[i] + (1 - mmt) * num_in_bin
            weights[inds] = tot / acc_sum[i]
        else:
            weights[inds] = tot / num_in_bin
        n += 1
if n > 0:
    weights = weights / n

loss = F.binary_cross_entropy_with_logits(
    pred, target, weights, reduction='sum') / tot
return loss * loss_weight

def GHM_Reg_loss(pred, target, label_weight, avg_factor=None):
mu = 0.02
mmt = 0.70
bins = 10
loss_weight = 1.0
edges = torch.arange(bins + 1).float().cuda() / bins
edges[-1] = 1e3
if mmt > 0:
acc_sum = torch.zeros(bins).cuda()

# ASL1 loss
diff = pred - target
loss = torch.sqrt(diff * diff + mu * mu) - mu

# gradient length
g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
weights = torch.zeros_like(g)

valid = label_weight > 0
tot = max(label_weight.float().sum().item(), 1.0)
n = 0  # n: valid bins
for i in range(bins):
    inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
    num_in_bin = inds.sum().item()
    if num_in_bin > 0:
        n += 1
        if mmt > 0:
            acc_sum[i] = mmt * acc_sum[i] \
                    + (1 - mmt) * num_in_bin
            weights[inds] = tot / acc_sum[i]
        else:
            weights[inds] = tot / num_in_bin
if n > 0:
    weights /= n

loss = loss * weights
loss = loss.sum() / tot
return loss * loss_weight

TODO: add loss evaluator for SSD

@HEADS.register_module
class SSDHead(AnchorHead):

def __init__(self,
             input_size=300,
             num_classes=81,
             in_channels=(512, 1024, 512, 256, 256, 256),
             anchor_strides=(8, 16, 32, 64, 100, 300),
             basesize_ratio_range=(0.1, 0.9),
             anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
             target_means=(.0, .0, .0, .0),
             target_stds=(1.0, 1.0, 1.0, 1.0)):
    super(AnchorHead, self).__init__()
    self.input_size = input_size
    self.num_classes = num_classes
    self.in_channels = in_channels
    self.cls_out_channels = num_classes
    num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
    reg_convs = []
    cls_convs = []
    for i in range(len(in_channels)):
        reg_convs.append(
            nn.Conv2d(
                in_channels[i],
                num_anchors[i] * 4,
                kernel_size=3,
                padding=1))
        cls_convs.append(
            nn.Conv2d(
                in_channels[i],
                num_anchors[i] * num_classes,
                kernel_size=3,
                padding=1))
    self.reg_convs = nn.ModuleList(reg_convs)
    self.cls_convs = nn.ModuleList(cls_convs)

    min_ratio, max_ratio = basesize_ratio_range
    min_ratio = int(min_ratio * 100)
    max_ratio = int(max_ratio * 100)
    step = int(np.floor(max_ratio - min_ratio) / (len(in_channels) - 2))
    min_sizes = []
    max_sizes = []
    for r in range(int(min_ratio), int(max_ratio) + 1, step):
        min_sizes.append(int(input_size * r / 100))
        max_sizes.append(int(input_size * (r + step) / 100))
    if input_size == 300:
        if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
            min_sizes.insert(0, int(input_size * 7 / 100))
            max_sizes.insert(0, int(input_size * 15 / 100))
        elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
            min_sizes.insert(0, int(input_size * 10 / 100))
            max_sizes.insert(0, int(input_size * 20 / 100))
    elif input_size == 512:
        if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
            min_sizes.insert(0, int(input_size * 4 / 100))
            max_sizes.insert(0, int(input_size * 10 / 100))
        elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
            min_sizes.insert(0, int(input_size * 7 / 100))
            max_sizes.insert(0, int(input_size * 15 / 100))
    self.anchor_generators = []
    self.anchor_strides = anchor_strides
    for k in range(len(anchor_strides)):
        base_size = min_sizes[k]
        stride = anchor_strides[k]
        ctr = ((stride - 1) / 2., (stride - 1) / 2.)
        scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
        ratios = [1.]
        for r in anchor_ratios[k]:
            ratios += [1 / r, r]  # 4 or 6 ratio
        anchor_generator = AnchorGenerator(
            base_size, scales, ratios, scale_major=False, ctr=ctr)
        indices = list(range(len(ratios)))
        indices.insert(1, len(indices))
        anchor_generator.base_anchors = torch.index_select(
            anchor_generator.base_anchors, 0, torch.LongTensor(indices))
        self.anchor_generators.append(anchor_generator)

    self.target_means = target_means
    self.target_stds = target_stds
    self.use_sigmoid_cls = False
    self.cls_focal_loss = False
    self.fp16_enabled = False

def init_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            xavier_init(m, distribution='uniform', bias=0)

def forward(self, feats):
    cls_scores = []
    bbox_preds = []
    for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                        self.cls_convs):
        cls_scores.append(cls_conv(feat))
        bbox_preds.append(reg_conv(feat))
    return cls_scores, bbox_preds
    


def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                bbox_targets, bbox_weights, num_total_samples, cfg):
    loss_cls = GHM_Class_loss(
        cls_score, 
        labels, 
        label_weights)
    

    loss_bbox = GHM_Reg_loss(
        bbox_pred,
        bbox_targets,
        bbox_weights,
        avg_factor=None)
    return loss_cls[None], loss_bbox

def loss(self,
         cls_scores,
         bbox_preds,
         gt_bboxes,
         gt_labels,
         img_metas,
         cfg,
         gt_bboxes_ignore=None):
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    assert len(featmap_sizes) == len(self.anchor_generators)

    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, img_metas)
    cls_reg_targets = anchor_target(
        anchor_list,
        valid_flag_list,
        gt_bboxes,
        img_metas,
        self.target_means,
        self.target_stds,
        cfg,
        gt_bboxes_ignore_list=gt_bboxes_ignore,
        gt_labels_list=gt_labels,
        label_channels=1,
        sampling=False,
        unmap_outputs=False)
    if cls_reg_targets is None:
        return None
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = cls_reg_targets

    num_images = len(img_metas)
    all_cls_scores = torch.cat([
        s.permute(0, 2, 3, 1).reshape(
            num_images, -1, self.cls_out_channels) for s in cls_scores
    ], 1)
    all_labels = torch.cat(labels_list, -1).view(num_images, -1)
    all_label_weights = torch.cat(label_weights_list,
                                  -1).view(num_images, -1)
    all_bbox_preds = torch.cat([
        b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
        for b in bbox_preds
    ], -2)
    all_bbox_targets = torch.cat(bbox_targets_list,
                                 -2).view(num_images, -1, 4)
    all_bbox_weights = torch.cat(bbox_weights_list,
                                 -2).view(num_images, -1, 4)

    losses_cls, losses_bbox = multi_apply(
        self.loss_single,
        all_cls_scores,
        all_bbox_preds,
        all_labels,
        all_label_weights,
        all_bbox_targets,
        all_bbox_weights,
        num_total_samples=num_total_pos,
        cfg=cfg)
    return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)