torch.max(), in-place operation

I have been trying to figure this issue out by myself for a day, but I have no idea what is going on here. The detailed error is shown below:

K:\Anaconda\envs\cv\lib\site-packages\torch\autograd\__init__.py:197: UserWarning: Error detected in MaxBackward0. Traceback of forward call that caused the error:
  File "K:\BaiduSyncdisk\CodeProject\pythonProject\DeepLearningForComputerVision\ObjectDetection\train.py", line 166, in <module>
    main(args)
  File "K:\BaiduSyncdisk\CodeProject\pythonProject\DeepLearningForComputerVision\ObjectDetection\train.py", line 94, in main
    mean_loss, lr = train_one_epoch(model, optimizer, train_data_loader,
  File "K:\BaiduSyncdisk\CodeProject\pythonProject\DeepLearningForComputerVision\ObjectDetection\train_utils\train_eval_utils.py", line 33, in train_one_epoch
    loss_dict = model(images, cats, bboxes)
  File "K:\Anaconda\envs\cv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "K:\BaiduSyncdisk\CodeProject\pythonProject\DeepLearningForComputerVision\ObjectDetection\models\YOLO1.py", line 227, in forward
    pred_max_conf, max_conf_idx = torch.max(pred_conf, dim=-1)
  File "K:\Anaconda\envs\cv\lib\site-packages\torch\fx\traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "K:\BaiduSyncdisk\CodeProject\pythonProject\DeepLearningForComputerVision\ObjectDetection\train.py", line 166, in <module>
    main(args)
  File "K:\BaiduSyncdisk\CodeProject\pythonProject\DeepLearningForComputerVision\ObjectDetection\train.py", line 94, in main
    mean_loss, lr = train_one_epoch(model, optimizer, train_data_loader,
  File "K:\BaiduSyncdisk\CodeProject\pythonProject\DeepLearningForComputerVision\ObjectDetection\train_utils\train_eval_utils.py", line 55, in train_one_epoch
    losses.backward()
  File "K:\Anaconda\envs\cv\lib\site-packages\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "K:\Anaconda\envs\cv\lib\site-packages\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [98]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

And here is the code that produces this error (a part of it). It's inside a for loop, and pred_conf is a [98, 2] tensor. I use torch.max() to get the max value and the corresponding index along the last axis, which gives pred_max_conf of shape [98] and max_conf_idx of shape [98] respectively. mask is a [98] boolean tensor that marks the positions with / without objects. self.conf_loss calculates the loss between the predictions and the ground truth, which is then accumulated over the batch.
I don't think I have any in-place operations on pred_max_conf, so I'm quite confused.
I would be very thankful if anybody can help!

        for pred_bbox, pred_conf, gt_bbox, gt_conf, mask in zip(pred_bboxes, pred_confs,
                                                                gt_bboxes, gt_confs,
                                                                masks):
            # todo: one of the variables needed for gradient computation has been modified by an inplace operation
            pred_max_conf, max_conf_idx = torch.max(pred_conf, dim=-1)
            max_conf_idx *= 5
            # confidence loss
            # balance loss between grids w and w/o objects
            a1, a2 = torch.sum(mask) / mask.size(0), torch.sum(~mask) / mask.size(0)

            loss_conf_w_obj = loss_conf_w_obj + self.conf_loss(pred_max_conf[mask], gt_conf[mask]) * a2
            loss_conf_wo_obj = loss_conf_wo_obj + self.conf_loss(pred_max_conf[~mask], gt_conf[~mask]) * a1

            # only count images with front ground and then filter grids without any objects
            gt_bbox = gt_bbox[mask]
            pred_bbox = pred_bbox[mask]
            if pred_bbox.numel() > 0:
                pred_bbox = self.boxer.filter_bboxes(pred_bbox, max_conf_idx[mask])
                # coord loss
                loss_coord = loss_coord + self.coord_loss(pred_bbox, gt_bbox)
        loss_coord = loss_coord / B
        loss_conf_w_obj = loss_conf_w_obj / B
        loss_conf_wo_obj = loss_conf_wo_obj / B
        loss_class = self.ce_loss(pred_classes.permute(0, 2, 1), gt_cats.long())

        total_loss = 5 * loss_coord + 0.5 * loss_conf_wo_obj + loss_conf_w_obj + loss_class

The function self.conf_loss():

    @staticmethod
    def conf_loss(pred, gt):
        # squared error between sigmoid(pred) and gt, summed and divided by the size of dim 0
        pred = torch.sigmoid(pred)
        loss = (pred - gt) ** 2

        return torch.sum(loss.reshape(-1)) / loss.size(0)

The PyTorch version is 1.13.0.

Hi Junhao!

*= is an in-place operation. From what you’ve said, max_conf_idx will
be a LongTensor of shape [98], so this is almost certainly your culprit.

As a work-around you could try max_conf_idx = 5 * max_conf_idx.

(This version creates a new tensor and sets the python reference
max_conf_idx to refer to it, leaving the old version of max_conf_idx
unchanged for use in that part of the backward pass.)
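
A minimal sketch of that change (using a dummy pred_conf standing in for your [98, 2] tensor):

import torch

pred_conf = torch.randn(98, 2, requires_grad=True)   # dummy stand-in for your [98, 2] tensor

pred_max_conf, max_conf_idx = torch.max(pred_conf, dim=-1)
# out-of-place multiply: creates a new tensor and rebinds the name,
# so the index tensor autograd saved for MaxBackward0 stays at version 0
max_conf_idx = 5 * max_conf_idx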

Best.

K. Frank

Thank you so much! It works! But I'm quite confused.
max_conf_idx.requires_grad == False, so I thought max_conf_idx was not used for gradient computation; that's why I never suspected the problem was caused by max_conf_idx.

Hi Junhao!

Autograd does not backpropagate through max_conf_idx (that is, it does
not track what upstream variables max_conf_idx might depend on), but it
does use max_conf_idx in the gradient computation.

Consider:

>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> t = torch.ones (3, requires_grad = True)
>>> idx = torch.tensor ([2, 1, 0])   # long tensor, no grad possible
>>> f = torch.tensor ([3.0, 5.0, 7.0])
>>>
>>> loss = (f * t[idx]).sum()
>>> loss.backward()
>>> t.grad   # backward uses idx to map grads to the correct elements of t
tensor([7., 5., 3.])
>>>
>>> t.grad = None
>>>
>>> loss = (f * t[idx]).sum()
>>> idx *= 0   # in-place operation on idx -- correct mapping is lost
>>> loss.backward()   # raises inplace error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor [3]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

(Note, idx is a LongTensor, so it can’t have requires_grad = True and
you can’t backpropagate through it, but it is part of the forward pass and is
still needed to perform the gradient computation correctly.)
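
For comparison, here is a sketch of the same example with the out-of-place work-around; rebinding idx to a new tensor leaves the index tensor saved for backward at version 0, so backward() succeeds:

>>> t = torch.ones (3, requires_grad = True)
>>> idx = torch.tensor ([2, 1, 0])
>>> f = torch.tensor ([3.0, 5.0, 7.0])
>>>
>>> loss = (f * t[idx]).sum()
>>> idx = 0 * idx   # out-of-place -- rebinds idx, saved index tensor unchanged
>>> loss.backward()   # no error this time
>>> t.grad
tensor([7., 5., 3.])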

Best.

K. Frank