Problems with Dice Loss in PyTorch Ignite

Hi, I am having issues with Dice loss and PyTorch Ignite. I am trying to reproduce the results of TernausNet using Dice loss, but my gradients keep being zero and the loss either does not improve or shows very strange values (negative, NaN, etc.). I am not sure where to look for the source of the issue. Below is the code for DiceLoss:

from torch import nn
from torch.nn import functional as F

# from catalyst.contrib.nn import DiceLoss
import torch


class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed per sample.

    Each sample is flattened (all channels and pixels together), its Dice
    coefficient is computed over those elements, and ``1 - dice`` is
    averaged over the batch.

    Targets are expected to hold values in ``[0, 1]`` (0/1 masks). Masks
    stored as 0/255 push the "dice" ratio above 1 and make the loss
    negative, so rescale them before calling this loss.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, targets, eps=1, threshold=None):
        """Return the mean soft Dice loss over the batch.

        Args:
            logits: raw (pre-sigmoid) model outputs; dim 0 is the batch.
            targets: ground-truth masks with the same batch size and the same
                number of elements per sample as ``logits``; values are
                expected in ``[0, 1]``.
            eps: smoothing term added to numerator and denominator so the
                ratio stays defined when prediction and target are both empty.
            threshold: if not ``None``, binarize probabilities with
                ``proba > threshold`` before computing Dice. Thresholding is
                not differentiable, so use it for evaluation only.

        Returns:
            Scalar tensor ``mean(1 - dice)`` over the batch.
        """
        # Comment out if your model already ends in a sigmoid (or an
        # equivalent activation layer).
        proba = torch.sigmoid(logits)

        # Flatten each sample to (batch, n_elements). ``reshape`` also accepts
        # non-contiguous inputs, where ``view`` would raise.
        proba = proba.reshape(proba.shape[0], -1)
        # Cast so integer/bool masks combine with the float probabilities.
        targets = targets.reshape(targets.shape[0], -1).to(proba.dtype)

        # ``is not None`` so that an explicit threshold of 0.0 is honoured.
        if threshold is not None:
            proba = (proba > threshold).to(proba.dtype)

        # Reduce over the element axis of the (batch, n_elements) tensors so
        # Dice is computed once per sample, not once per pixel.
        intersection = torch.sum(proba * targets, dim=1)
        summation = torch.sum(proba, dim=1) + torch.sum(targets, dim=1)
        dice = (2.0 * intersection + eps) / (summation + eps)
        return (1 - dice).mean()

And here is the model (the UNet11 backbone code is taken from TernausNet):

import pytorch_lightning as pl
import torch
from torch import nn
from torchvision import models
from carvana_unet.utils import DiceLoss

def conv3x3(in_: int, out: int) -> nn.Module:
    """Build a 3x3 convolution that keeps height and width unchanged.

    Stride is the ``nn.Conv2d`` default (1) and padding is half the kernel
    size, so the spatial size of the output equals that of the input.
    """
    kernel_size = 3
    same_padding = kernel_size // 2
    return nn.Conv2d(in_, out, kernel_size, padding=same_padding)


class ConvRelu(nn.Module):
    """A size-preserving 3x3 convolution followed by an in-place ReLU."""

    def __init__(self, in_: int, out: int) -> None:
        super().__init__()
        # The attribute name ``conv`` prefixes this layer's state_dict keys,
        # so it is kept stable for checkpoint compatibility.
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, then the activation."""
        return self.activation(self.conv(x))


class DecoderBlock(nn.Module):
    """Decoder stage: 3x3 conv + ReLU, then a 2x upsampling transposed conv + ReLU.

    The transposed convolution (kernel 3, stride 2, padding 1,
    output_padding 1) maps a spatial size H to 2 * H.
    """

    def __init__(
        self, in_channels: int, middle_channels: int, out_channels: int
    ) -> None:
        super().__init__()

        # Layers are created in this exact order so random initialisation
        # and the Sequential indices (block.0 / block.1 / block.2) stay fixed.
        layers = [ConvRelu(in_channels, middle_channels)]
        layers.append(
            nn.ConvTranspose2d(
                middle_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            )
        )
        layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through the conv / upsample / activation stack."""
        upsampled = self.block(x)
        return upsampled


class UNet11Lightning(pl.LightningModule):
    """TernausNet-style UNet11 (VGG11 encoder) producing one output channel.

    ``forward`` returns raw logits; the default loss (``DiceLoss``) applies
    the sigmoid itself, and ``test_step`` applies it for predictions.
    """

    def __init__(
        self, num_filters: int = 32, pretrained: bool = True, loss_fn=DiceLoss
    ) -> None:
        """
        Args:
            num_filters: base width of the decoder. The skip concatenations
                combine decoder outputs with VGG11's fixed encoder widths, so
                the channel counts below only line up for ``num_filters=32``
                (e.g. ``dec1`` expects ``num_filters + 64 == 3 * num_filters``).
            pretrained:
                False - no pre-trained network is used
                True  - encoder is pre-trained with VGG11
            loss_fn: loss *class* (not instance); it is instantiated with no
                arguments and called as ``loss_fn(logits, targets)``.
        """
        super().__init__()
        self.loss_fn = loss_fn()
        self.pool = nn.MaxPool2d(2, 2)

        self.encoder = models.vgg11(pretrained=pretrained).features

        # Reuse individual layers of torchvision's VGG11 feature extractor
        # (indices are positions inside ``vgg11().features``).
        self.relu = self.encoder[1]
        self.conv1 = self.encoder[0]
        self.conv2 = self.encoder[3]
        self.conv3s = self.encoder[6]
        self.conv3 = self.encoder[8]
        self.conv4s = self.encoder[11]
        self.conv4 = self.encoder[13]
        self.conv5s = self.encoder[16]
        self.conv5 = self.encoder[18]

        self.center = DecoderBlock(
            num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8
        )
        self.dec5 = DecoderBlock(
            num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8
        )
        self.dec4 = DecoderBlock(
            num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4
        )
        self.dec3 = DecoderBlock(
            num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2
        )
        self.dec2 = DecoderBlock(
            num_filters * (4 + 2), num_filters * 2 * 2, num_filters
        )
        self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)

        self.final = nn.Conv2d(num_filters, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel logits with a single channel.

        The encoder applies five 2x max-pools and the decoder upsamples 2x
        five times, so H and W must be divisible by 32 for the skip
        concatenations to match in size.
        """
        conv1 = self.relu(self.conv1(x))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3s = self.relu(self.conv3s(self.pool(conv2)))
        conv3 = self.relu(self.conv3(conv3s))
        conv4s = self.relu(self.conv4s(self.pool(conv3)))
        conv4 = self.relu(self.conv4(conv4s))
        conv5s = self.relu(self.conv5s(self.pool(conv4)))
        conv5 = self.relu(self.conv5(conv5s))

        center = self.center(self.pool(conv5))

        # Each decoder stage consumes the previous stage concatenated
        # (on the channel axis) with the encoder feature map of equal size.
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
        return self.final(dec1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def _shared_step(self, batch):
        """Compute the loss for one ``{"features", "target"}`` batch."""
        # Move to whatever device the module lives on instead of hard-coding
        # .cuda(), so the module also runs on CPU (no-op if already there).
        x = batch["features"].to(self.device)
        y = batch["target"].to(self.device)
        # NOTE(review): DiceLoss / BCE expect 0/1 float masks; 0/255 masks
        # produce negative or diverging losses -- confirm the dataset.
        y_hat = self(x)
        return self.loss_fn(y_hat, y)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Log through the returned dict, matching validation_step. A Result
        # object that is not returned is ignored, so it cannot log anything.
        return {"loss": loss, "log": {"train_loss": loss}}

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        tensorboard_log = {"val_loss": loss}
        return {"loss": loss, "log": tensorboard_log}

    def test_step(self, batch, batch_idx):
        """Return per-pixel foreground probabilities for the batch."""
        x = batch["features"].to(self.device)
        y_hat = self(x)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(y_hat)

Can anybody help me find the source of the issue, or point me in the right direction? I have tried other losses (like BCE with logits, which kept decreasing towards -inf), and nothing seems to work.

Moving to the Ignite category for better visibility. CC @vfdev-5

@notacode check out the pytorch-ignite quick-start guide on how to use pytorch-ignite without pl.LightningModule :slight_smile: