Loss.backward() runtime error after 2nd batch iteration

Hello,
while trying to train my network I am getting this annoying error. Looks like the gradients are not working as expected or something.
Anyway I’ll be happy to provide you with more code and instructions if needed.

THANKS !

Warning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "/home/kd-6d-pose-adlp/train_kd.py", line 135, in <module>
    _, loss_dict = model(images, targets=targets, pred_t=pred_t, cfg_kd=cfg_kd)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kd-6d-pose-adlp/models/model_kd.py", line 84, in forward
    pred_cls, pred_reg, pred_evi = self.head(features)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kd-6d-pose-adlp/models/model.py", line 465, in forward
    evidential_pred = self.scales[l](self.evidential_pred(pose_tower))
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kd-6d-pose-adlp/models/model.py", line 24, in forward
    return input * self.scale
 (function _print_stack)
steps: 1/20000, lr:0.000040, cls:13.9574, reg:41.6412, kd:0.0000, evi:41.6412:   8%|▊         | 1/12 [00:31<05:50, 31.82s/it]
Traceback (most recent call last):
  File "/home/kd-6d-pose-adlp/train_kd.py", line 166, in <module>
    loss.backward(retain_graph=True)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/comet_ml/monkey_patching.py", line 317, in wrapper
    raise exception_raised
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/comet_ml/monkey_patching.py", line 288, in wrapper
    return_value = original(*args, **kwargs)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1]] is at version 2; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Based on the stacktrace it seems you are keeping the computation graph alive and are thus probably trying to use stale forward activations from the computation graph created in the first iteration in the second iteration. If you get stuck narrowing down which operation appends to the computation graph, please post a minimal and executable code snippet reproducing this issue.

Hey,
thanks for your answer.
I added

retain_graph=True

to loss.backward() when I started getting this error. My original code doesn’t have it.
Here is more context :

class Scale(nn.Module):
   def __init__(self, init_value=1.0):
       super().__init__()

       self.scale = nn.Parameter(torch.tensor([init_value], dtype=torch.float32))

   def forward(self, input):
       return input * self.scale
class PoseHead(nn.Module):
    def __init__(self, in_channel, n_class, n_conv, prior, regression_type):
        super(PoseHead, self).__init__()
        num_classes = n_class - 1
        num_anchors = 1

        self.regression_type = regression_type

        cls_tower = []
        pose_tower = []
        for i in range(n_conv):
            conv_func = nn.Conv2d

            cls_tower.append(
                conv_func(
                    in_channel,
                    in_channel,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=True
                )
            )
            cls_tower.append(nn.GroupNorm(32, in_channel))
            # cls_tower.append(nn.BatchNorm2d(in_channel))
            cls_tower.append(nn.ReLU())
            pose_tower.append(
                conv_func(
                    in_channel,
                    in_channel,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=True
                )
            )
            pose_tower.append(nn.GroupNorm(32, in_channel))
            # cls_tower.append(nn.BatchNorm2d(in_channel))
            pose_tower.append(nn.ReLU())

        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('pose_tower', nn.Sequential(*pose_tower))
        self.cls_logits = nn.Conv2d(
            in_channel, num_anchors * num_classes, kernel_size=3, stride=1,
            padding=1
        )
        self.pose_pred = nn.Conv2d(
            in_channel, num_anchors * num_classes * 16, kernel_size=3, stride=1,
            padding=1
        )

        # Additional layer for uncertainty estimation
        self.evidential_pred = Conv2DNormalGamma(
            in_channel,
            filters = num_anchors * num_classes * 16,
            kernel_size=3,
            stride=1,
            padding=1
        )

        # initialization
        for modules in [self.cls_tower, self.pose_tower,
                        self.cls_logits, self.pose_pred, self.evidential_pred]:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)

        # initialize the bias for focal loss
        prior_prob = prior
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)
       

        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])

    def forward(self, x):
        with torch.autograd.set_detect_anomaly(True):
            logits = []
            pose_reg = []
            evidential_reg = []  # Store uncertainty estimates mu, v, alpha and beta
            centerness = []
            for l, feature in enumerate(x):
                cls_tower = self.cls_tower(feature)
                pose_tower = self.pose_tower(feature)

                logits.append(self.cls_logits(cls_tower))

                pose_pred = self.scales[l](self.pose_pred(pose_tower))
                pose_reg.append(pose_pred)

                # Predict uncertainty
                evidential_pred = self.scales[l](self.evidential_pred(pose_tower))
                evidential_reg.append(evidential_pred)

            return logits, pose_reg, evidential_reg

Could you describe why you’ve added retain_graph=True as it’s common to raise such issues if this argument is not needed?

The original code doesn’t need it. I added it the first time I got the error trying to override the error (which didn’t work obviously). Removing it doesn’t solve the problem tho. I also checked if there is some not “in place” operations in my code and couldn’t find so it is not coming from that also.

Using retain_graph=True to mask another error won’t work so I assume the original error you’ve seen was pointing to Autograd trying to call backward on an already freed computation graph? If so, check which operation is appending to the previous computation graph.

Hey again, thanks for your answer. I solved it, partially I guess, now I have another issue. The memory usage increases after each iteration and never stops (Until the training process is shutdown).

            loss = loss_cls + loss_reg + loss_evi
            optimizer.zero_grad()
            loss.backward()
            del loss
            nn.utils.clip_grad_norm_(model.parameters(), cfg['SOLVER']['GRAD_CLIP'])
            optimizer.step()
            scheduler.step()

I tried to add dell loss but it does not solve it. I also tried loss_cls.item() + loss_reg.item() + loss_evi.item() but then torch doesn’t accept it.

Check if you are appending the loss or another tensor attached to the computation graph somewhere as it’s a common cause of an unexpected memory increase.

I can provide more context, maybe it will help because honestly I don’t see it :

scaling_factor = 50 # 0.02d
losses = nn.SmoothL1Loss(reduction='none')(scaling_factor * px, scaling_factor * target_3D_in_camera_frame).view(cellNum, -1)
losses = losses / scaling_factor

evi_loss, reg_loss = EvidentialRegression(losses, target_2D, pred_filtered, pred_v_filtered, pred_alpha_filtered, pred_beta_filtered, cellNum)
evi_loss = evi_loss.mean(dim=1)
reg_loss = reg_loss.mean(dim=1)

def EvidentialRegression(error, y_true, gamma, v, alpha, beta, cellNum, coeff=1.0):

    gamma = gamma.mean(dim=1).unsqueeze(1).repeat(1, 16).detach()

    loss_evi = NIG_NLL(y_true, gamma, v, alpha, beta, reduce=False)
    
    loss_reg = NIG_Reg(error, y_true, gamma, v, alpha, beta, reduce=False, kl=False)

    return loss_evi, coeff * loss_reg

def NIG_NLL(y, gamma, v, alpha, beta, reduce=True):
    twoBlambda = 2 * beta * (1 + v)
    nll = 0.5 * torch.log(np.pi / v) - alpha * torch.log(twoBlambda) + (alpha + 0.5) * torch.log(v * (y - gamma)**2 + twoBlambda) + torch.lgamma(alpha) - torch.lgamma(alpha + 0.5)
    return torch.mean(nll, dim=0) if reduce else nll

def NIG_Reg(error, gamma, v, alpha, beta, omega=0.01, reduce=True, kl=False):
    error[:, :8] = (error[:, :8] + error[:, 16:24]) / 2
    error[:, 8:16] = (error[:, 8:16] + error[:, 16:24]) / 2
    error = torch.cat((error[:, :8], error[:, 8:16]), dim=1).detach()

    if kl:
        kl = KL_NIG(gamma, v, alpha, beta, gamma, omega, 1 + omega, beta)
        reg = error * kl
    else:
        evi = 2 * v + alpha
        reg = error * evi

    return torch.mean(reg) if reduce else reg

Thanks again!!

The loss calculations themselves might not be too interesting, but how these tensors are stored or used afterwards. In your latest code snippet, how are evi_loss and reg_loss used? Do you see an expected memory usage if you delete these tensors directly and is something else failing (e.g. where these losses might be attached to)?