Hi, I am having issues with Dice Loss and Pytorch Ignite. I am trying to reproduce the result of Ternausnet using dice loss but my gradients keep being zero and loss just does not improve or shows very strange results (negative, nan, etc). I am not sure where to look for a possible source of the issue. Below is the code for DiceLoss:
from torch import nn
from torch.nn import functional as F
# from catalyst.contrib.nn import DiceLoss
import torch
class DiceLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, targets, eps=1, threshold=None):
# comment out if your model contains a sigmoid or
# equivalent activation layer
proba = torch.sigmoid(logits)
proba = proba.view(proba.shape[0], 1, -1)
targets = targets.view(targets.shape[0], 1, -1)
if threshold:
proba = (proba > threshold).float()
# flatten label and prediction tensors
intersection = torch.sum(proba * targets, dim=1)
summation = torch.sum(proba, dim=1) + torch.sum(targets, dim=1)
dice = (2.0 * intersection + eps) / (summation + eps)
# print(intersection, summation, dice)
return (1 - dice).mean()
and here is the model (unet11 backbone code taken from ternausnet):
import pytorch_lightning as pl
import torch
from torch import nn
from torchvision import models
from carvana_unet.utils import DiceLoss
def conv3x3(in_: int, out: int) -> nn.Module:
return nn.Conv2d(in_, out, 3, padding=1)
class ConvRelu(nn.Module):
def __init__(self, in_: int, out: int) -> None:
super().__init__()
self.conv = conv3x3(in_, out)
self.activation = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
x = self.activation(x)
return x
class DecoderBlock(nn.Module):
def __init__(
self, in_channels: int, middle_channels: int, out_channels: int
) -> None:
super().__init__()
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(
middle_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
),
nn.ReLU(inplace=True),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.block(x)
class UNet11Lightning(pl.LightningModule):
def __init__(
self, num_filters: int = 32, pretrained: bool = True, loss_fn=DiceLoss
) -> None:
"""
Args:
num_filters:
pretrained:
False - no pre-trained network is used
True - encoder is pre-trained with VGG11
"""
super().__init__()
self.loss_fn = loss_fn()
self.pool = nn.MaxPool2d(2, 2)
self.encoder = models.vgg11(pretrained=pretrained).features
self.relu = self.encoder[1]
self.conv1 = self.encoder[0]
self.conv2 = self.encoder[3]
self.conv3s = self.encoder[6]
self.conv3 = self.encoder[8]
self.conv4s = self.encoder[11]
self.conv4 = self.encoder[13]
self.conv5s = self.encoder[16]
self.conv5 = self.encoder[18]
self.center = DecoderBlock(
num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8
)
self.dec5 = DecoderBlock(
num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8
)
self.dec4 = DecoderBlock(
num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4
)
self.dec3 = DecoderBlock(
num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2
)
self.dec2 = DecoderBlock(
num_filters * (4 + 2), num_filters * 2 * 2, num_filters
)
self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)
self.final = nn.Conv2d(num_filters, 1, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
conv1 = self.relu(self.conv1(x))
conv2 = self.relu(self.conv2(self.pool(conv1)))
conv3s = self.relu(self.conv3s(self.pool(conv2)))
conv3 = self.relu(self.conv3(conv3s))
conv4s = self.relu(self.conv4s(self.pool(conv3)))
conv4 = self.relu(self.conv4(conv4s))
conv5s = self.relu(self.conv5s(self.pool(conv4)))
conv5 = self.relu(self.conv5(conv5s))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
return self.final(dec1)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def training_step(self, batch, batch_idx):
x, y = batch["features"].cuda(), batch["target"].cuda()
y_hat = self(x)
loss = self.loss_fn(y_hat, y)
result = pl.TrainResult(loss)
result.log("train_loss", loss)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
x, y = batch["features"].cuda(), batch["target"].cuda()
y_hat = self(x)
loss = self.loss_fn(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log("val_loss", loss, prog_bar=True, on_step=True)
tensorboard_log = {"val_loss": loss}
return {"loss": loss, "log": tensorboard_log}
def test_step(self, batch, batch_idx):
x = batch["features"].cuda()
y_hat = self(x)
return torch.nn.functional.sigmoid(y_hat)
can anybody help me find the source of the issue or point me into the direction of where to look? I have tried other losses (like BCE with logits which was decreasing into -inf) and nothing seems to work.