How to avoid "one of the variables needed for gradient computation has been modified by an inplace operation"

I'm trying to implement the Tversky metric as follows:

import torch
from torch import nn
def tversky_score(
        predictions,
        targets,
        alpha=0.5,
        beta=0.5,
        eps=1e-16,
        encode_target=True
):
    """
    Ref A: https://arxiv.org/abs/1706.05721
    alpha = beta = 0.5 : Dice coefficient
    alpha = beta = 1 : Tanimoto coefficient / Jaccard Index
    alpha + beta = 1 : Produces set of F*-scores
    :param predictions:
    :param targets:
    :param alpha:
    :param beta:
    :param eps:
    :param encode_target:
    :return:
    """
    if encode_target:
        # noinspection PyArgumentList
        gt = predictions.clone().zero_()
        # noinspection PyArgumentList
        assert (targets.ndim == 4) and (
                targets.size(1) == 1
        ), "Only single channel images can be encoded!"
        gt = gt.scatter_(1, targets.to(dtype=torch.int64), 1)
        # gt.requires_grad = True
    else:
        gt = targets.clone()
        # gt.requires_grad = True
    # flatten label and prediction tensors
    predictions = predictions.view(-1)
    gt = gt.view(-1)

    # True Positives, False Positives & False Negatives
    tp = (predictions * gt).sum()
    fp = ((1 - gt) * predictions).sum()
    fn = (gt * (1 - predictions)).sum()

    tversky = (tp + eps) / (
            tp + alpha * fp + beta * fn + eps
    )
    return tversky
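The special cases in the docstring follow directly from the formula TI = TP / (TP + alpha*FP + beta*FN). With alpha = beta = 0.5 it becomes 2TP / (2TP + FP + FN), the Dice coefficient. With alpha = beta = 1 it becomes TP / (TP + FP + FN), the Jaccard index. The snippet below is a quick numerical check. It assumes soft predictions in [0, 1] and a binary ground truth. The helper `tversky` is a hypothetical stripped-down version of `tversky_score` that skips the target-encoding step.

```python
import torch

def tversky(p, g, alpha, beta, eps=1e-16):
    # Hypothetical stripped-down tversky_score: p and g are already flat.
    tp = (p * g).sum()
    fp = ((1 - g) * p).sum()
    fn = (g * (1 - p)).sum()
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)

torch.manual_seed(0)
p = torch.rand(100)                    # soft predictions in [0, 1)
g = (torch.rand(100) > 0.5).float()    # binary ground truth

tp = (p * g).sum()
fp = ((1 - g) * p).sum()
fn = (g * (1 - p)).sum()
dice = 2 * tp / (2 * tp + fp + fn)
jaccard = tp / (tp + fp + fn)

print(torch.allclose(tversky(p, g, 0.5, 0.5), dice))     # True
print(torch.allclose(tversky(p, g, 1.0, 1.0), jaccard))  # True
```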

My loss functions are:

class FocalTverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, gamma=1, smooth=1e-16):
        """
        Ref A: https://arxiv.org/abs/1706.05721
        alpha = beta = 0.5, gamma = 1 : Dice coefficient
        alpha = beta = 1, gamma = 1 : Tanimoto coefficient / Jaccard Index
        alpha + beta = 1, gamma = 1 : Produces set of F*-scores
        :param alpha:
        :param beta:
        :param gamma:
        :param smooth:
        """
        super(FocalTverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.eps = smooth

    def forward(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        encode_target: bool = True
    ):
        tversky = tversky_score(
            predictions=predictions,
            targets=targets,
            encode_target=encode_target,
            alpha=self.alpha,
            beta=self.beta,
            eps=self.eps
        )
        focal_tversky = (1 - tversky) ** self.gamma
        return focal_tversky

class PCHuberLoss(nn.Module):
    def __init__(self, reduction='mean', delta=1.0):
        super(PCHuberLoss, self).__init__()
        self.net = nn.HuberLoss(
            reduction=reduction, delta=delta
        )

    def forward(self, predictions, targets):
        return self.net(predictions.F, targets.F)

class HybridLoss(nn.Module):
    def __init__(
            self,
            alpha=0.5,
            beta=0.5,
            gamma=1,
            smooth=1e-16,
            reduction='mean',
            delta=1.0,
            p_factor=0.5,
            i_factor=0.5,
            encode_label=True
    ):
        super(HybridLoss, self).__init__()
        self.im_criterion = FocalTverskyLoss(
            alpha=alpha,
            beta=beta,
            gamma=gamma,
            smooth=smooth
        )
        self.pc_criterion = PCHuberLoss(
            reduction=reduction,
            delta=delta
        )
        self.p_f = p_factor
        self.i_f = i_factor
        self.img_encode = encode_label

    def forward(self, predictions, targets):
        loss = self.im_criterion(
            predictions=predictions['label_tensor'],
            targets=targets['label_tensor'],
            encode_target=self.img_encode
        )
        if 'label' in predictions.keys():
            loss = (self.i_f * loss) + (
                self.p_f * self.pc_criterion(
                    predictions=predictions['point_tensor'],
                    targets=targets['point_tensor']
                )
            )
        return loss

I use HybridLoss to calculate the loss. In the loss.backward() call I encounter RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. Possibly this is due to the zero_() call, since it is an in-place operation. However, I am not sure how I can fix this. I would appreciate some pointers.

Hi Abhisek!

The zero_() call is likely the problem. (.clone() without .detach() is still part
of the computation graph, and you then .zero_() the cloned tensor
in place.)

It looks like you want to initialize gt as a zero-tensor of the same
shape as predictions. Try:

gt = torch.zeros_like (predictions)
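The difference is easy to see with a minimal sketch. A hypothetical random `predictions` tensor stands in for the network output. The `.clone().zero_()` result is still attached to the graph: it has a `grad_fn` and `requires_grad=True`. By contrast, `torch.zeros_like` returns a fresh leaf tensor with no history.

```python
import torch

# Hypothetical stand-in for network output: (N, C, H, W) with gradients enabled.
predictions = torch.rand(2, 3, 4, 4, requires_grad=True)

# Original approach: the clone is recorded in the graph, then zeroed in place.
gt_clone = predictions.clone().zero_()

# Suggested approach: new tensor with the same shape/dtype/device, no history.
gt_fresh = torch.zeros_like(predictions)

print(gt_clone.requires_grad, gt_clone.grad_fn is not None)  # True True
print(gt_fresh.requires_grad, gt_fresh.grad_fn)              # False None
```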

Best.

K. Frank

Thanks for the suggestion, @KFrank. I had already tried that, as shown below:

def tversky_score(
        predictions,
        targets,
        alpha=0.5,
        beta=0.5,
        eps=1e-16,
        encode_target=True
):
    if encode_target:
        # noinspection PyArgumentList
        # gt = predictions.clone()  # .zero_()
        gt = torch.zeros_like(predictions)
        # noinspection PyArgumentList
        assert (targets.ndim == 4) and (
                targets.size(1) == 1
        ), "Only single channel images can be encoded!"
        gt = gt.scatter_(1, targets.to(dtype=torch.int64), 1)
        gt.requires_grad = True
    else:
        gt = targets.clone()
        gt.requires_grad = True
    # flatten label and prediction tensors
    predictions = predictions.view(-1)
    gt = gt.view(-1)

    # True Positives, False Positives & False Negatives
    tp = (predictions * gt).sum()
    fp = ((1 - gt) * predictions).sum()
    fn = (gt * (1 - predictions)).sum()

    tversky = (tp + eps) / (
            tp + alpha * fp + beta * fn + eps
    )
    return tversky

but I still get the exact same runtime error (one of the variables needed for gradient computation has been modified by an inplace operation). Is it because of the line

 gt = gt.scatter_(1, targets.to(dtype=torch.int64), 1)

Is it necessary to set gt.requires_grad = True, given that an entirely separate tensor is being created outside the graph?

Hi Abhisek!

Could you post a trimmed-down, runnable, fully-self-contained script
that reproduces the error? (Please have the script print out the version
of pytorch you’re using, torch.__version__.)

Please provide hard-coded (or random) sample data (with size and
dimensionality as small as possible while still reproducing the error).
Also delete any comments not relevant to the issue, and see if you can
get rid of the assert and the if encode_target branch, and simplify or leave
out some of the calculations without losing the error.
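A trimmed-down script along those lines might look like the sketch below. It prints the PyTorch version, seeds the RNG, and builds tiny random inputs; the shapes are arbitrary assumptions. It then runs only the score computation and `backward()`. If this runs cleanly, as it should, then the in-place operation behind the error must be somewhere else, for example in the model's forward pass.

```python
import torch

print(torch.__version__)
torch.manual_seed(0)

# Tiny hypothetical inputs: 1 image, 2 classes, 3x3 pixels.
logits = torch.randn(1, 2, 3, 3, requires_grad=True)
predictions = logits.softmax(dim=1)
targets = torch.randint(0, 2, (1, 1, 3, 3))

# One-hot encode the single-channel targets along dim 1, as tversky_score does.
gt = torch.zeros_like(predictions)
gt.scatter_(1, targets, 1)

p, g = predictions.reshape(-1), gt.reshape(-1)
tp = (p * g).sum()
fp = ((1 - g) * p).sum()
fn = (g * (1 - p)).sum()
loss = 1 - tp / (tp + 0.5 * fp + 0.5 * fn)

loss.backward()              # no in-place error from this part alone
print(logits.grad.shape)     # torch.Size([1, 2, 3, 3])
```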

In isolation, this in-place scatter_() shouldn’t break anything.

(Note that the assignment to gt is redundant because gt is modified
in-place and gt.scatter_() merely returns the original gt object, as
modified.)
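A quick check of that point, using arbitrary small shapes: `scatter_()` returns the very same tensor it modified, so a bare `gt.scatter_(...)` statement is enough. The out-of-place `scatter()` is the variant that leaves its input untouched and returns a new tensor.

```python
import torch

gt = torch.zeros(1, 3, 2, 2)
targets = torch.tensor([[[[0, 2], [1, 1]]]])  # shape (1, 1, 2, 2), class indices

returned = gt.scatter_(1, targets, 1)
print(returned is gt)      # True: same Python object, modified in place
print(gt[0, :, 0, 1])      # tensor([0., 0., 1.]), since pixel (0, 1) has class 2

# Out-of-place variant: input stays zero, a new encoded tensor is returned.
fresh = torch.zeros(1, 3, 2, 2)
encoded = fresh.scatter(1, targets, 1)
print(encoded is fresh, fresh.sum().item(), encoded.sum().item())  # False 0.0 4.0
```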

No, this shouldn’t be necessary, and, in general, you don’t want to
be doing this. In this specific case, doing so shouldn’t break anything,
but, more generally, you could be asking for trouble down the line.

As a general rule, if you find you need to set requires_grad = True
for anything other than an “upstream” tensor whose value you want
to optimize, it’s a sign that you’re doing something wrong.
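The sketch below illustrates this; the shapes and the simplified score are assumptions. Gradients reach the prediction tensor whether or not `gt` requires grad. Setting `gt.requires_grad = True` only makes autograd compute a `gt.grad` that nothing uses.

```python
import torch

def soft_tp(pred, gt):
    # Simplified stand-in for the true-positive term of the Tversky score.
    return (pred * gt).sum()

torch.manual_seed(0)
pred = torch.rand(4, requires_grad=True)  # the "upstream" tensor being optimized
gt = torch.tensor([0., 1., 1., 0.])       # fixed target, no requires_grad

soft_tp(pred, gt).backward()
print(pred.grad)  # tensor([0., 1., 1., 0.]), i.e. d(sum(pred * gt))/d(pred) = gt
print(gt.grad)    # None: no gradient is computed for the target

gt.requires_grad = True                   # allowed here (gt is a leaf) but unnecessary
pred.grad = None
soft_tp(pred, gt).backward()
print(gt.grad is not None)                # True: extra work spent on a constant target
```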

Best.

K. Frank

@KFrank Thanks for your suggestions. Following your recommendations, I narrowed down the part causing this issue. It turns out that in one layer block I had used torch.nn.ReLU(True) (that is, inplace=True) without paying much attention to it. The error occurred because this ReLU modified its input in place.
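For anyone hitting the same error: `torch.nn.ReLU(True)` sets `inplace=True`, so the ReLU overwrites its input tensor. Backward then fails whenever the preceding operation saved that tensor for its gradient computation. Below is a minimal reproduction. It places a sigmoid in front of the ReLU purely for illustration, since sigmoid's backward reuses its output. The fix is `inplace=False`, which is the default. In a plain Conv→ReLU stack the in-place version is usually harmless, which is probably why it is easy to overlook.

```python
import torch
from torch import nn

def run(inplace: bool) -> str:
    x = torch.randn(5, requires_grad=True)
    h = torch.sigmoid(x)             # sigmoid saves its output h for backward
    y = nn.ReLU(inplace=inplace)(h)  # inplace=True overwrites h
    try:
        y.sum().backward()
        return "ok"
    except RuntimeError as err:
        # Keep only the first part of the autograd error message.
        return "error: " + str(err).split(":")[0]

print(run(inplace=True))   # error: one of the variables needed for gradient computation has been modified by an inplace operation
print(run(inplace=False))  # ok
```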
