One of the Variables required has been modified by inplace operation

Hello all,
I have run into the following error and I do not know why or what causes this error -

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 120087]], which is output 0 of SelectBackward, is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

My loss functions are as follows -

class FocalLoss(nn.Module):
    # def __init__(self):

    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]
        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:,4] != -1]
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            if bbox_annotation.shape[0] == 0:
                # print("here")
                if torch.cuda.is_available():
                    alpha_factor = torch.ones(classification.shape).cuda() * alpha

                    alpha_factor = 1.0 - alpha_factor
                    focal_weight = classification
                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

                    bce = -(torch.log(1.0 - classification))

                    # cls_loss = focal_weight * torch.pow(bce, gamma)
                    cls_loss = focal_weight * bce
                    classification_losses.append(cls_loss.sum())
                    regression_losses.append(torch.tensor(0).float().cuda())

                else:
                    alpha_factor = torch.ones(classification.shape) * alpha

                    alpha_factor = 1.0 - alpha_factor
                    focal_weight = classification
                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

                    bce = -(torch.log(1.0 - classification))

                    # cls_loss = focal_weight * torch.pow(bce, gamma)
                    cls_loss = focal_weight * bce
                    classification_losses.append(cls_loss.sum())
                    regression_losses.append(torch.tensor(0).float())

                continue
            IoU = calc_iou(
                anchors[0, :, :], bbox_annotation
            )  # num_anchors x num_annotations

            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # num_anchors x 1

            # import pdb
            # pdb.set_trace()

            # compute the loss for classification
            targets = torch.ones(classification.shape) * -1

            if torch.cuda.is_available():
                targets = targets.cuda()

            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[
                positive_indices, assigned_annotations[positive_indices, 4].long()
            ] = 1

            if torch.cuda.is_available():
                alpha_factor = torch.ones(targets.shape).cuda() * alpha
            else:
                alpha_factor = torch.ones(targets.shape) * alpha

            alpha_factor = torch.where(
                torch.eq(targets, 1.0), alpha_factor, 1.0 - alpha_factor
            )
            focal_weight = torch.where(
                torch.eq(targets, 1.0), 1.0 - classification, classification
            )
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(
                targets * torch.log(classification)
                + (1.0 - targets) * torch.log(1.0 - classification)
            )

            # cls_loss = focal_weight * torch.pow(bce, gamma)
            cls_loss = focal_weight * bce

            if torch.cuda.is_available():
                cls_loss = torch.where(
                    torch.ne(targets, -1.0),
                    cls_loss,
                    torch.zeros(cls_loss.shape).cuda(),
                )
            else:
                cls_loss = torch.where(
                    torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape)
                )

            classification_losses.append(
                cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0)
            )

            # compute the loss for regression

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths to 1
                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                if torch.cuda.is_available():
                    targets = targets / torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()
                else:
                    targets = targets / torch.Tensor([[0.1, 0.1, 0.2, 0.2]])

                negative_indices = 1 + (~positive_indices)

                regression_diff = torch.abs(targets - regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0,
                )
                regression_losses.append(regression_loss.mean())
            else:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).float().cuda())
                else:
                    regression_losses.append(torch.tensor(0).float())

        return (
            torch.stack(classification_losses).mean(dim=0, keepdim=True),
            torch.stack(regression_losses).mean(dim=0, keepdim=True),
        )

class softIoU(nn.Module):
    def __init__(self):
        super(softIoU, self).__init__()
        pass
    def forward(self, cScores, predictions, gt):
        cScores = cScores.view((cScores.shape[2], cScores.shape[1]))
        IoUs = calc_iou(predictions.squeeze(0), gt.squeeze(0))
        loss = (IoUs*torch.log(cScores)) + ((1-IoUs)*torch.log(1-cScores))
        print(loss.mean())
        return loss.mean()

and I did set torch.autograd.set_detect_anomaly(True)
and it prints the following -

[W python_anomaly_mode.cpp:104] Warning: Error detected in ClampBackward. Traceback of forward call that caused the error:
  File "/home/atharva/Attentive-RetinaNet/main.py", line 41, in <module>
    trainResnet(dataloader=dataloader, epochs=500)
  File "/home/atharva/Attentive-RetinaNet/main.py", line 24, in trainResnet
    loss = model((image, gt))
  File "/home/atharva/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/atharva/Attentive-RetinaNet/retinanet/model.py", line 354, in forward
    transformed_anchors = self.clipBoxes(transformed_anchors, img_batch)
  File "/home/atharva/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/atharva/Attentive-RetinaNet/retinanet/utils.py", line 215, in forward
    boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height)

The box utils function is as follows -

class ClipBoxes(nn.Module):
    def __init__(self, width=None, height=None):
        super(ClipBoxes, self).__init__()

    def forward(self, boxes, img):

        batch_size, num_channels, height, width = img.shape

        boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
        boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)

        boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width)
        boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height)

        return boxes

I have no clue how to proceed, any help is appreciated.

The boxes tensor is performing inplace operations that can be fixed with a .clone() call on the right hand side:

boxes[:, :, 0] = torch.clamp(boxes.clone()[:, :, 0], min=0)
boxes[:, :, 1] = torch.clamp(boxes.clone()[:, :, 1], min=0)
boxes[:, :, 2] = torch.clamp(boxes.clone()[:, :, 2], max=width)
boxes[:, :, 3] = torch.clamp(boxes.clone()[:, :, 3], max=height)

Some more discussion here:

1 Like

Thanks a lot @gbilla, it works now