Sudden large loss in one branch stops losses in the other branches from being optimized

I’m training an object detection model with YOLO structure:

class YOLO(nn.Module):
    def __init__(self, anchors, num_classes, ...):  # other args omitted
        super(YOLO, self).__init__()
        self.anchors = anchors
        self.backbone = resnet50(pretrained=True)  # pretrained weights, classification layer already removed
        self.neck = nn.Sequential(
            conv3x3(in_planes=2048, out_planes=2048),
            nn.ReLU(),
            conv3x3(in_planes=2048, out_planes=2048),
            nn.ReLU(),
            conv3x3(in_planes=2048, out_planes=2048),
            nn.ReLU(),
        )
        self.heads = nn.ModuleList(
            [YOLOHead(num_classes) for anchor in self.anchors]  # one head per anchor
        )

    def forward(self, x, ...):
        x = self.backbone(x)
        x = self.neck(x)
        outs = [head(x) for head in self.heads]
        # ... (loss computation against the targets omitted; the real forward
        # returns a dict like {"loss": {"loss": ..., "objectness_loss": ..., ...}})

class YOLOHead(nn.Module):
    def __init__(self, num_classes):
        super(YOLOHead, self).__init__()
        self.objectness = conv1x1(in_planes=2048, out_planes=1)
        self.cls = conv1x1(in_planes=2048, out_planes=1+num_classes)
        self.reg = conv1x1(in_planes=2048, out_planes=4)

    def forward(self, x):
        objectness = self.objectness(x)
        cls = self.cls(x)
        reg = self.reg(x)
        return objectness, cls, reg

I optimize the sum of three losses: (1) objectness, (2) classification, and (3) bbox regression. As shown above, the heads compute them separately, but they share the backbone.
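The loss computation itself isn't shown in the code above; it is essentially a plain sum of the three terms, roughly like this (a simplified sketch: the anchor matching / target building is omitted, and the objectness / classification criteria here are only placeholders, only the regression term is literally mse_loss):

    import torch.nn.functional as F

    # Simplified sketch of how the total loss is formed (not my exact code):
    obj_loss = F.binary_cross_entropy_with_logits(objectness_pred, objectness_target)  # placeholder criterion
    cls_loss = F.cross_entropy(cls_pred, cls_target)                                   # placeholder criterion
    reg_loss = F.mse_loss(reg_pred, reg_target)                                        # this is the term that spikes
    loss = obj_loss + cls_loss + reg_loss                                              # single scalar passed to backward()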
The issue is that the regression loss (mse_loss between ground truth and prediction) sometimes spikes suddenly and is then optimized back down to a moderate level. But after that recovery, the objectness and classification losses can hardly be optimized any more. My wording may be vague, but it looks like this (note the frozen obj_loss and cls_loss after batch 26):

[epoch: 1, batch: 1] loss: 7.862009, obj_loss:  1.040800, cls_loss:  4.567673, reg_loss:  2.253536
[epoch: 1, batch: 2] loss: 7.364972, obj_loss:  1.027697, cls_loss:  4.538699, reg_loss:  1.798576
[epoch: 1, batch: 3] loss: 7.077137, obj_loss:  1.006557, cls_loss:  4.509902, reg_loss:  1.560679
[epoch: 1, batch: 4] loss: 6.991296, obj_loss:  0.959222, cls_loss:  4.449108, reg_loss:  1.582966
[epoch: 1, batch: 5] loss: 7.313764, obj_loss:  0.869381, cls_loss:  4.270252, reg_loss:  2.174131
[epoch: 1, batch: 6] loss: 6.717353, obj_loss:  0.832326, cls_loss:  4.162328, reg_loss:  1.722699
[epoch: 1, batch: 7] loss: 7.118721, obj_loss:  0.969238, cls_loss:  3.606729, reg_loss:  2.542755
[epoch: 1, batch: 8] loss: 6.670866, obj_loss:  0.907451, cls_loss:  3.944198, reg_loss:  1.819217
[epoch: 1, batch: 9] loss: 7.152078, obj_loss:  0.937133, cls_loss:  4.139555, reg_loss:  2.075390
[epoch: 1, batch: 10] loss: 6.950064, obj_loss:  0.961618, cls_loss:  4.229391, reg_loss:  1.759055
[epoch: 1, batch: 11] loss: 6.504709, obj_loss:  0.969900, cls_loss:  4.112210, reg_loss:  1.422600
[epoch: 1, batch: 12] loss: 6.989936, obj_loss:  0.993160, cls_loss:  3.952706, reg_loss:  2.044071
[epoch: 1, batch: 13] loss: 6.952768, obj_loss:  0.883456, cls_loss:  3.616553, reg_loss:  2.452760
[epoch: 1, batch: 14] loss: 7.013877, obj_loss:  0.963605, cls_loss:  4.107985, reg_loss:  1.942287
[epoch: 1, batch: 15] loss: 6.275753, obj_loss:  0.981065, cls_loss:  3.834120, reg_loss:  1.460567
[epoch: 1, batch: 16] loss: 6.559465, obj_loss:  0.931920, cls_loss:  3.854708, reg_loss:  1.772837
[epoch: 1, batch: 17] loss: 6.501263, obj_loss:  0.988465, cls_loss:  4.012724, reg_loss:  1.500074
[epoch: 1, batch: 18] loss: 6.336897, obj_loss:  0.972874, cls_loss:  3.923963, reg_loss:  1.440061
[epoch: 1, batch: 19] loss: 7.756456, obj_loss:  0.907921, cls_loss:  4.357026, reg_loss:  2.491509
[epoch: 1, batch: 20] loss: 6.757920, obj_loss:  1.009806, cls_loss:  4.316257, reg_loss:  1.431857
[epoch: 1, batch: 21] loss: 6.218307, obj_loss:  0.975769, cls_loss:  4.141428, reg_loss:  1.101111
[epoch: 1, batch: 22] loss: 11.127349, obj_loss:  1.423922, cls_loss:  6.209508, reg_loss:  3.493918
[epoch: 1, batch: 23] loss: 7.093385, obj_loss:  1.035403, cls_loss:  4.525311, reg_loss:  1.532671
[epoch: 1, batch: 24] loss: 7.644183, obj_loss:  1.034627, cls_loss:  4.520189, reg_loss:  2.089367
[epoch: 1, batch: 25] loss: 6.772291, obj_loss:  1.016327, cls_loss:  4.244712, reg_loss:  1.511253
[epoch: 1, batch: 26] loss: 193.881104, obj_loss:  1.830120, cls_loss:  17.074070, reg_loss:  174.976913
[epoch: 1, batch: 27] loss: 7.431651, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.825148
[epoch: 1, batch: 28] loss: 7.192277, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.585774
[epoch: 1, batch: 29] loss: 6.893618, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.287114
[epoch: 1, batch: 30] loss: 7.638692, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  2.032189
[epoch: 1, batch: 31] loss: 7.449893, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.843388
[epoch: 1, batch: 32] loss: 7.655496, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.048993
[epoch: 1, batch: 33] loss: 7.145531, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.539027
[epoch: 1, batch: 34] loss: 7.390171, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.783668
[epoch: 1, batch: 35] loss: 7.374976, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.768474
[epoch: 1, batch: 36] loss: 7.157423, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.550920
[epoch: 1, batch: 37] loss: 7.023059, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.416556
[epoch: 1, batch: 38] loss: 7.025788, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.419284
[epoch: 1, batch: 39] loss: 7.916700, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.310196
[epoch: 1, batch: 40] loss: 7.503070, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.896566
[epoch: 1, batch: 41] loss: 7.642205, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.035701
[epoch: 1, batch: 42] loss: 7.002790, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.396288
[epoch: 1, batch: 43] loss: 7.220839, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.614336
[epoch: 1, batch: 44] loss: 6.997747, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.391244
[epoch: 1, batch: 45] loss: 7.289024, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.682521
[epoch: 1, batch: 46] loss: 8.104959, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.498456
[epoch: 1, batch: 47] loss: 7.818621, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.212118
[epoch: 1, batch: 48] loss: 7.451015, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.844513
[epoch: 1, batch: 49] loss: 7.190177, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.583674
[epoch: 1, batch: 50] loss: 7.414326, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.807823
[epoch: 1, batch: 51] loss: 7.082467, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.475964
[epoch: 1, batch: 52] loss: 8.054209, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.447706
[epoch: 1, batch: 53] loss: 7.361894, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.755390
[epoch: 1, batch: 54] loss: 7.404776, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.798274
[epoch: 1, batch: 55] loss: 7.673492, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.066990
[epoch: 1, batch: 56] loss: 7.255838, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.649335
[epoch: 1, batch: 57] loss: 7.364078, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.757576
[epoch: 1, batch: 58] loss: 7.672638, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.066136
[epoch: 1, batch: 59] loss: 6.943457, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.336953
[epoch: 1, batch: 60] loss: 7.074402, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.467899
[epoch: 1, batch: 61] loss: 7.573625, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.967122
[epoch: 1, batch: 62] loss: 6.797518, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.191013
[epoch: 1, batch: 63] loss: 7.667085, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  2.060580
[epoch: 1, batch: 64] loss: 7.461186, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.854683
[epoch: 1, batch: 65] loss: 7.218839, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.612336
[epoch: 1, batch: 66] loss: 6.966671, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.360166
[epoch: 1, batch: 67] loss: 7.419224, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.812721
[epoch: 1, batch: 68] loss: 7.038870, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.432365
[epoch: 1, batch: 69] loss: 7.527884, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.921382
[epoch: 1, batch: 70] loss: 6.735312, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.128809
[epoch: 1, batch: 71] loss: 7.051725, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.445222
[epoch: 1, batch: 72] loss: 7.819386, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.212884
[epoch: 1, batch: 73] loss: 7.343047, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.736544
[epoch: 1, batch: 74] loss: 6.986712, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.380208
[epoch: 1, batch: 75] loss: 7.757770, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.151267
[epoch: 1, batch: 76] loss: 7.547459, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.940956
[epoch: 1, batch: 77] loss: 7.939866, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.333363
[epoch: 1, batch: 78] loss: 7.063341, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.456838
[epoch: 1, batch: 79] loss: 7.327311, obj_loss:  1.039721, cls_loss:  4.566783, reg_loss:  1.720806
[epoch: 1, batch: 80] loss: 7.129947, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.523444
[epoch: 1, batch: 81] loss: 7.087235, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.480733
[epoch: 1, batch: 82] loss: 7.572030, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.965527
[epoch: 1, batch: 83] loss: 7.192766, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.586263
[epoch: 1, batch: 84] loss: 7.427009, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.820506
[epoch: 1, batch: 85] loss: 7.694969, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.088466
[epoch: 1, batch: 86] loss: 8.016901, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  2.410399
[epoch: 1, batch: 87] loss: 7.126433, obj_loss:  1.039721, cls_loss:  4.566782, reg_loss:  1.519930

This is the first epoch. obj_loss and cls_loss barely move after batch 26, yet reg_loss keeps fluctuating.
The training script:

    device = torch.device("cuda")
    dataloader_train = get_dataloader(train=True, batch_size=batch_size)
    dataset_val = VOCDetectionDataset(root="/data/sfy_projects/Datasets/VOC2007/VOCtrainval_06-Nov-2007", year="2007",
                                      image_set="val",
                                      transforms=get_transforms(train=False),
                                      show=False)
    coco_anno_path = "/data/sfy_projects/Datasets/VOC2007/VOCtrainval_06-Nov-2007/voc2007.val.cocoformat.json"

    model = YOLO(anchors=k_means_anchors["2007.train"], num_classes=20).to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)

    for epoch in range(epochs):
        print("----------------------  TRAINING  ---------------------- ")
        model.train()
        model.freeze_backbone(layers=3)
        model.freeze_bn(model.backbone)
        running_loss, running_objectness_loss, running_classification_loss, running_regression_loss = 0.0, 0.0, 0.0, 0.0
        for i, data in enumerate(dataloader_train):
            imgs, targets = data
            imgs = imgs.to(device)
            for target in targets:
                for obj in target["objects"]:
                    obj["class"] = obj["class"].to(device)
                    obj["bbox"] = obj["bbox"].to(device)

            res = model(imgs, targets)
            loss = res["loss"]["loss"]
            objectness_loss = res["loss"]["objectness_loss"]
            classification_loss = res["loss"]["classification_loss"]
            regression_loss = res["loss"]["regression_loss"]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_objectness_loss += objectness_loss.item()
            running_classification_loss += classification_loss.item()
            running_regression_loss += regression_loss.item()
            if (i + 1) % batches_show == 0:
                print(f"[epoch: {epoch + 1}, batch: {i + 1}] loss: {running_loss / batches_show:.6f}, obj_loss: {running_objectness_loss / batches_show: .6f}, cls_loss: {running_classification_loss / batches_show: .6f}, reg_loss: {running_regression_loss / batches_show: .6f}")
                running_loss, running_objectness_loss, running_classification_loss, running_regression_loss = 0.0, 0.0, 0.0, 0.0
        lr_scheduler.step(loss)

I’ve tried omitting the regression loss from the optimization. That eventually gave me a reasonably good model that points out each object’s location and class, but has no awareness of its bbox size. I’m not sure what this proves, though.
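Concretely, for that experiment I just dropped the regression term from the total before calling backward(), along these lines (sketch):

    # "no regression" experiment: only obj + cls contribute to the gradients;
    # reg_loss is still computed and logged, but excluded from the total
    loss = objectness_loss + classification_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()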
My biggest confusion is this: since I already call optimizer.zero_grad() every batch, how can something that happened in one training batch have a lasting effect on the following batches? I can’t find any clue as to why this happens.
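Just to spell out my mental model of one batch (as far as I understand, zero_grad() only clears the .grad tensors, so the weight update from step() and the optimizer's internal state, e.g. the SGD momentum buffer since I use momentum=0.9, do carry over between batches):

    optimizer.zero_grad()   # resets p.grad for every parameter
    loss.backward()         # p.grad now reflects only the current batch
    optimizer.step()        # permanently changes the weights; SGD with momentum
                            # also updates its per-parameter momentum_buffer in
                            # optimizer.state, which zero_grad() never touches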
BTW, the problem is reproducible for me: it happens in the first epoch every time if I set the learning rate to 0.01. Since I lowered the learning rate to 0.001, the sudden rise in regression loss has never shown up again (as far as I’ve seen). Is all of this just about the learning rate?