UserWarning: Mixed memory format inputs

I have implemented a U-Net with a custom loss function: Dice loss.

While training, I receive the following warning:
/usr/local/lib/python3.6/dist-packages/ UserWarning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (Triggered internally at /pytorch/aten/src/ATen/native/TensorIterator.cpp:918.)
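For context on the terms in this warning: a 4-D tensor can hold the same values in the default NCHW ("contiguous") layout or in NHWC ("channels_last") order. Only the strides differ, not the shape. The following is a minimal illustration using standard PyTorch calls; the tensor sizes are arbitrary.

```python
import torch

# A 4-D tensor (N=2, C=3, H=4, W=4) in the default contiguous NCHW layout.
x = torch.randn(2, 3, 4, 4)
print(x.stride())  # (48, 16, 4, 1)

# The same values stored in channels_last (NHWC) order: the shape is unchanged,
# but the channel dimension now has stride 1.
y = x.to(memory_format=torch.channels_last)
print(y.shape, y.stride())  # torch.Size([2, 3, 4, 4]) (48, 1, 12, 3)

print(x.is_contiguous())                                   # True
print(y.is_contiguous())                                   # False
print(y.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(x, y))                                   # True: same values
```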

This warning appears only when I call the Dice loss function; it does not happen when I use PyTorch's torch.nn.CrossEntropyLoss().

This is my Dice implementation (copied from a webpage):

import torch
import torch.nn.functional as F

def calculate_dice_loss(logits, true, eps=1e-7):
    """Computes the Sørensen–Dice loss.

    PyTorch optimizers minimize a loss, but we want to maximize the
    Dice coefficient, so we return 1 - Dice instead.
        true: a tensor of shape [Batch size x 512 x 512] (integer class labels).
        logits: a tensor of shape [Batch size x numLabels x 512 x 512].
            Corresponds to the raw output or logits of the model.
        eps: added to the denominator for numerical stability.
    Returns:
        dice_loss: the Sørensen–Dice loss.
    """
    num_classes = logits.shape[1]
    true = true.unsqueeze(1)  # now true: a tensor of shape [Batch size x 1 x 512 x 512]
    # one-hot encode the labels (on the logits' device): [Batch size x 512 x 512 x numLabels]
    true_1_hot = torch.eye(num_classes, device=logits.device)[true.squeeze(1)]
    # move the class dimension to dim 1: [Batch size x numLabels x 512 x 512]
    true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
    probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    # reduce over the batch and spatial dims, keeping one value per class
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice_loss = (2. * intersection / (cardinality + eps)).mean()
    return (1 - dice_loss)
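For reference, writing p for the softmax probabilities (`probas`), g for the one-hot targets (`true_1_hot`) and C for the number of classes, the function above computes a soft Dice coefficient per class c, summed over the batch n and pixels (h, w), and returns one minus its mean over classes:

```latex
D_c = \frac{2 \sum_{n,h,w} p_{n,c,h,w}\, g_{n,c,h,w}}
           {\sum_{n,h,w} \left( p_{n,c,h,w} + g_{n,c,h,w} \right) + \varepsilon},
\qquad
\mathcal{L}_{\text{Dice}} = 1 - \frac{1}{C} \sum_{c=1}^{C} D_c
```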

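A likely trigger (an assumption, not confirmed in this thread) is the one-hot step: indexing `torch.eye` produces an NHWC tensor, and `permute(0, 3, 1, 2)` only reorders strides without copying, so `true_1_hot` ends up with exactly the channels_last strides while `probas` is in the default layout. The sketch below replays those steps on small hypothetical inputs (spatial size 4 instead of 512) and checks the layouts; calling `.contiguous()` after the permute is one possible fix.

```python
import torch
import torch.nn.functional as F

num_classes, batch, h, w = 3, 2, 4, 4

# Hypothetical inputs with the shapes from the docstring (512 reduced to 4).
logits = torch.randn(batch, num_classes, h, w)
true = torch.randint(0, num_classes, (batch, h, w))

# Same steps as calculate_dice_loss: eye-indexing gives [N, H, W, C],
# and permute only swaps strides, yielding channels_last-style strides.
true_1_hot = torch.eye(num_classes)[true].permute(0, 3, 1, 2).float()
probas = F.softmax(logits, dim=1)

print(probas.is_contiguous())                                       # True
print(true_1_hot.is_contiguous())                                   # False
print(true_1_hot.is_contiguous(memory_format=torch.channels_last))  # True
# So probas * true_1_hot mixes a contiguous and a channels_last-strided input.

# Possible fix: copy the one-hot tensor into the default NCHW layout.
true_1_hot_fixed = true_1_hot.contiguous()
print(true_1_hot_fixed.is_contiguous())           # True
print(torch.equal(true_1_hot, true_1_hot_fixed))  # True: values unchanged
```

The same applies if `F.one_hot(true, num_classes).permute(0, 3, 1, 2)` is used instead of `torch.eye` indexing: the permuted result still needs `.contiguous()` to match the layout of `probas`. This also fits the observation that `CrossEntropyLoss` does not warn, since it takes the integer labels directly and never builds a permuted one-hot tensor.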
And this is how I call the loss function during training:

    # loss values
    CrossEntropy_Loss = CrossEntropy_criterion(logits, true_labels)
    dice_loss = calculate_dice_loss(logits, true_labels)

    loss = dice_loss

    # Back propagation


Are you using the channels_last memory format in your model?
If not, could you print the shapes of logits and true_labels before passing them into the loss function?
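A small diagnostic along these lines could report shapes and memory layouts in one go. `describe` is a hypothetical helper, and the random tensors below are stand-ins for the real `logits` and `true_labels` from the training loop.

```python
import torch

def describe(name: str, t: torch.Tensor) -> str:
    # Hypothetical helper: report shape, dtype and memory layout.
    # channels_last is only meaningful for 4-D tensors, so check dim() first.
    channels_last = t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last)
    info = (f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}, "
            f"contiguous={t.is_contiguous()}, channels_last={channels_last}")
    print(info)
    return info

# Stand-ins for the tensors passed to the loss function.
logits = torch.randn(2, 3, 4, 4)
true_labels = torch.randint(0, 3, (2, 4, 4))
describe("logits", logits)
describe("true_labels", true_labels)
describe("logits (channels_last)", logits.to(memory_format=torch.channels_last))
```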