Questions regarding torch.no_grad for processing targets in loss forward method

I want to implement a binary boundary loss inspired by https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/boundary_loss.py and https://arxiv.org/abs/1812.07032. I have seen with torch.no_grad() used in the forward method of a loss, and I am not sure whether it is recommended there. Since it is applied only to the labels, which are not yet attached to the computational graph, is it a good idea to use it? Will it influence training?

Is it better to have multiple inputs, like def forward(self, preds, targets, distance_map):?

"""
Inspired by https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/boundary_loss.py,
which lists a lot of PyTorch losses for segmentation.
"""

import numpy as np
from scipy.ndimage import distance_transform_edt
import torch


def compute_edts_forhdloss(segmentation):
    res = np.zeros(segmentation.shape)
    for i in range(segmentation.shape[0]):
        posmask = segmentation[i]
        negmask = ~posmask
        res[i] = distance_transform_edt(posmask) + distance_transform_edt(negmask)
    return res


class BinaryBDLoss(torch.nn.Module):
    def __init__(self, num_channel: int, ndim=4, per_channel: bool = True, reduction: str = "mean"):
        """
        compute boundary loss
        only compute the loss of foreground
        ref: https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L74
        """
        super().__init__()
        assert ndim in [4, 5]
        self.per_channel = per_channel
        self.num_channel = num_channel
        self.ndim = ndim
        self.axis = list(range(ndim))
        self.reduction = reduction

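        # `and`, not `&`: bitwise `&` binds tighter than `>`, so the original
        # expression was evaluated as `self.num_channel > (1 & self.per_channel)`.
        if self.num_channel > 1 and self.per_channel: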
            if self.ndim == 5:
                self.axis = [0, -3, -2, -1]
            if self.ndim == 4:
                self.axis = [0, -2, -1]
        # self.do_bg = do_bg

    def forward(self, preds, targets):
        """
        preds: (batch_size, class, x,y)
        targets: ground truth, shape: (batch_size, 1, x,y)
        bound (computed below from targets): distance map, shape (batch_size, class, x,y)
        """
        if preds.shape[1] != self.num_channel:
            preds = preds.view(preds.shape[0], -1, *preds.shape[1:])
            targets = targets.view(*preds.shape)
            # distance = distance.view(*preds.shape)

        with torch.no_grad():
            bound = compute_edts_forhdloss(targets.cpu().numpy() > 0.5)

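        # Tensor.to() returns a new tensor, so the result must be assigned.
        # The distance map is a constant weight and does not need gradients.
        bound = torch.from_numpy(bound).to(device=preds.device, dtype=preds.dtype)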

        # print('preds shape: ', preds.shape)
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        # self.idc: List[int] = kwargs["idc"]
        # pc = preds[:, self.idc, ...].type(torch.float32)
        # dc = bound[:, self.idc, ...].type(torch.float32)

        if preds.ndim == 5:
            multipled = torch.einsum("bcxyz,bcxyz->bcxyz", preds, bound)
        else:
            multipled = torch.einsum("bcwh,bcwh->bcwh", preds, bound)

        bd_loss = torch.mean(multipled, dim=self.axis)

        if self.reduction == 'mean':
            bd_loss = bd_loss.mean()
            return bd_loss

        if self.reduction == 'sum':
            bd_loss = bd_loss.sum()
            return bd_loss

        if self.reduction == 'none':
            return bd_loss

        raise Exception('Unexpected reduction {}'.format(self.reduction))

I assume you are referring to these lines of code:

with torch.no_grad():
    bound = compute_edts_forhdloss(targets.cpu().numpy() > 0.5)

If so, the torch.no_grad() wrapper shouldn’t make any difference: you are passing numpy arrays to the method, so Autograd won’t be able to track these operations anyway.
Generally, the guard is used to explicitly disable gradient tracking, so that no computation graph is created for the operations inside the block. If none of the tensors used there require gradients, the guard makes no difference.
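
A small check illustrates this. It is only a sketch: the tensor names (targets, weight) are made up for the example and stand in for the labels and a model parameter.

```python
import numpy as np
import torch

targets = torch.rand(2, 1, 4, 4)            # labels: requires_grad is False by default
weight = torch.rand(1, requires_grad=True)  # hypothetical stand-in for a model parameter

# NumPy work is invisible to Autograd, with or without the guard.
with torch.no_grad():
    inside = torch.from_numpy((targets.numpy() > 0.5).astype(np.float32))
outside = torch.from_numpy((targets.numpy() > 0.5).astype(np.float32))

same_values = torch.equal(inside, outside)
tracked = (inside.requires_grad, outside.requires_grad)

# The guard only matters once a tensor that requires gradients is involved.
with torch.no_grad():
    guarded = weight * 2
unguarded = weight * 2
```

Here inside and outside are identical and neither is tracked, while guarded.requires_grad is False and unguarded.requires_grad is True.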

Good to know. Is it good practice to convert the labels to numpy and then back to a tensor? Autograd operates on the network weights, and the targets are not yet linked to the network, so will it cause any problem? I guess it would for the network predictions, but not for the ground-truth targets.

It shouldn’t be necessary to convert the tensor to a numpy array and back, and I would avoid doing so.
If none of the tensors involved in processing the target require gradients, the no_grad() block is not necessary either.

The problem is distance_transform_edt from scipy.ndimage, which converts its input to a numpy array anyway.

I have already removed the torch.no_grad and replaced bound = compute_edts_forhdloss(targets.cpu().numpy() > 0.5) with bound = compute_edts_forhdloss(targets > 0.5), since the explicit conversion was not necessary; the result is still a numpy array. If I understand correctly, even though it is only applied to the target, you recommend not calling distance_transform_edt directly in the forward pass, and finding another way instead (either passing the distance mask instead of the segmentation mask as the second parameter of the forward pass, or passing the distance mask as a third parameter).

The reason I was doing it this way was to pass only the segmentation mask as the ground-truth parameter to the forward pass, and to be able to combine it easily with a regional loss (like Dice loss), as in the article. That way I do not have to manage multiple ground truths for one input. But I guess you are suggesting that I do it the other way (precomputing the distance mask in the dataset and passing it as a parameter to the forward pass anyway).
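
A minimal sketch of that alternative, assuming the masks fit in memory as numpy arrays. The names MaskWithDistanceDataset, distance_map and boundary_loss are hypothetical, and the distance map uses the same sum of two transforms as compute_edts_forhdloss above, applied per channel.

```python
import numpy as np
import torch
from scipy.ndimage import distance_transform_edt
from torch.utils.data import Dataset


def distance_map(mask: np.ndarray) -> np.ndarray:
    # Same combination as compute_edts_forhdloss, for a single boolean mask.
    return distance_transform_edt(mask) + distance_transform_edt(~mask)


class MaskWithDistanceDataset(Dataset):  # hypothetical name
    def __init__(self, images, masks):
        self.images = images  # array of shape (N, C, H, W)
        self.masks = masks    # array of shape (N, 1, H, W)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        mask = self.masks[idx] > 0.5
        # The distance map is computed on the CPU, outside the loss.
        dist = np.stack([distance_map(m) for m in mask])
        return (
            torch.as_tensor(self.images[idx], dtype=torch.float32),
            torch.as_tensor(mask, dtype=torch.float32),
            torch.as_tensor(dist, dtype=torch.float32),
        )


def boundary_loss(preds, distance_map_batch):
    # The loss itself now only needs tensor operations.
    return (preds * distance_map_batch).mean()


masks = np.zeros((3, 1, 8, 8), dtype=np.float32)
masks[:, :, 2:6, 2:6] = 1.0
images = np.random.rand(3, 1, 8, 8).astype(np.float32)
dataset = MaskWithDistanceDataset(images, masks)

image, mask, dist = dataset[0]
preds = torch.rand(1, 1, 8, 8, requires_grad=True)  # stand-in for network output
loss = boundary_loss(preds, dist.unsqueeze(0))
loss.backward()
```

The segmentation mask is still returned alongside the distance map, so a regional loss such as Dice can keep using it.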

Thanks for your input.

If you need to call a numpy method, e.g. because there is no corresponding method in PyTorch, your workflow is correct. The no_grad() wrapper wouldn’t change anything, which was the initial question, since Autograd won’t track any numpy operations.

You can still do the processing in the forward method, but note that it would run on the CPU (since numpy cannot use the GPU), and transferring the tensor to the CPU before calling numpy would add synchronizations.
That being said, functionality-wise your approach should work; performance could potentially be improved if you could stay on the GPU and use PyTorch methods to process the target.
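
For reference, the round trip could look roughly like this. It is a sketch only: edt_like is a hypothetical helper name, and the distance computation mirrors compute_edts_forhdloss from the question.

```python
import numpy as np
import torch
from scipy.ndimage import distance_transform_edt


def edt_like(targets: torch.Tensor, preds: torch.Tensor) -> torch.Tensor:  # hypothetical helper
    # .cpu() copies the data to the host (and synchronizes) if targets live on the GPU.
    masks = targets.detach().cpu().numpy() > 0.5
    out = np.zeros(masks.shape, dtype=np.float64)
    for i in range(masks.shape[0]):
        out[i] = distance_transform_edt(masks[i]) + distance_transform_edt(~masks[i])
    # .to() returns a new tensor; match the device and dtype of the predictions.
    return torch.from_numpy(out).to(device=preds.device, dtype=preds.dtype)


targets = torch.zeros(2, 1, 6, 6)
targets[:, :, 2:4, 2:4] = 1.0
preds = torch.rand(2, 1, 6, 6, requires_grad=True)  # stand-in for network output

bound = edt_like(targets, preds)
loss = (preds * bound).mean()
loss.backward()
```

Gradients flow to preds only; bound is a constant weight built from the targets.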

Thank you very much. That is very clear! I will see how to improve it eventually.