Validation accuracy weirdly dependent on validation batch size

Hi,

I’m doing a medical image segmentation task and using the dice score to validate the performance of my model after every epoch. Everything is going as expected (i.e. training is converging, results look good), except that my validation dice score seems to vary a lot with the batch size of my validation loader.

When using a validation batch size of 1, my average dice score is ~0.5, but when I change the validation batch size to 8, it jumps to about 0.65 after 10 epochs. My understanding is that the validation accuracy should be independent of the validation batch size used.

My training and validation code is as follows (unet is my U-Net model):

# unet, optimizer, scheduler, device and DiceDistanceCriterion are defined elsewhere
for epoch in range(epochs):

    epoch_loss = 0
    #Train model
    unet.train()  # Dropout/BatchNorm in training mode for the whole training phase
    for i_batch, data in enumerate(train_dataset_loader, 0):

        img_batch = data[0].to(device)
        true_masks_batch = data[1].to(device)

        mask_pred = unet(img_batch)

        loss = DiceDistanceCriterion(mask_pred, true_masks_batch)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i_batch % 50 == 0:

            train_dice_score = dice_score(mask_pred, true_masks_batch)
            print(train_dice_score)

    #Validate after each epoch
    unet.eval()  # Dropout off, BatchNorm uses its running statistics
    with torch.no_grad():
        val_dice_score = 0
        epoch_val_loss = 0
        for val_batch, val_sample in enumerate(val_dataset_loader,0):

            val_img = val_sample[0].to(device)
            val_mask = val_sample[1].to(device)
            pred_val_mask = unet(val_img)

            val_loss = DiceDistanceCriterion(pred_val_mask, val_mask)
            epoch_val_loss += val_loss.item()

            val_dice_score += dice_score(pred_val_mask, val_mask)

            if val_batch % 50 == 0:
                print(val_dice_score/(val_batch+1))

    epoch_loss = epoch_loss/(i_batch + 1)
    epoch_val_loss = epoch_val_loss/(val_batch + 1)
    val_dice_score = val_dice_score/(val_batch + 1)

    print('Epoch: {}/{} --- Val. Loss: {}'.format(epoch, epochs, epoch_val_loss))

    scheduler.step(epoch_loss)

And dice score is defined as:

def dice_score(input, target):
    eps = 1e-8

    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return ((2. * intersection + eps) /
            (iflat.sum() + tflat.sum() + eps))

But my validation dice score looks like this; as you can see, it changes considerably with batch size.

From what I can see, your evaluation code has a mistake: you use pred_val_mask to compute the dice score; however, it does not seem to be a binary tensor, since you also use the output of the model to compute the loss.

Good point, I will change that and re-validate, but I don’t think that would cause the validation dice score to change with batch size?

No, it should not change the results, although I am not sure what it was actually measuring: I would expect val_mask to be a LongTensor, but the dice computation function would not work in that case.

So I added this line in my validation code (after the loss function):

pred_val_mask = (pred_val_mask>0.5).float()

But as you can see, the discrepancies in dice score still exist when different validation batch sizes are used:

@Latope2-150 Yes, val_mask is a tensor of 1s and 0s representing the segmentation mask. Without the binary masking, the dice computation would have been measuring the ‘soft’ dice score instead of a binary dice score.
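To make the soft vs. binary distinction concrete, here is a minimal sketch with made-up probabilities for four pixels (the values are purely illustrative). It reuses the batch-level dice_score from the question, with reshape in place of view so it also accepts non-contiguous tensors.

```python
import torch

def dice_score(input, target):
    # The original batch-level dice_score from this thread
    # (reshape instead of view so non-contiguous tensors also work).
    eps = 1e-8
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return (2. * intersection + eps) / (iflat.sum() + tflat.sum() + eps)

# Hypothetical sigmoid outputs for four pixels and the matching binary mask.
probs = torch.tensor([0.9, 0.7, 0.2, 0.1])
mask = torch.tensor([1., 1., 0., 0.])

soft = dice_score(probs, mask)                  # uses the probabilities directly
hard = dice_score((probs > 0.5).float(), mask)  # thresholded at 0.5

print("soft dice:", soft.item())  # 3.2 / 3.9, about 0.82
print("hard dice:", hard.item())  # 1.0: every pixel is on the right side of 0.5
```

The soft score penalizes under-confident but correct pixels, whereas the thresholded score only counts which side of 0.5 each pixel falls on.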

Do you have a repo with all the code you are using for this (model class, dataset class, etc.)? I can’t seem to find a problem in this part of the code.

Is there any random function in your model class that is not switched off when calling eval()?
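One way to check this is to compare the output for one image processed alone with the output for the same image inside a batch. The sketch below uses a small toy network, not the actual U-Net code, that contains the two usual suspects, BatchNorm and Dropout; the layer sizes and input shapes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a U-Net block (illustrative only): BatchNorm and Dropout are
# the common layers whose behaviour differs between train() and eval().
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Dropout2d(p=0.5),
    nn.Conv2d(4, 1, kernel_size=1),
)

x = torch.randn(8, 1, 16, 16)  # hypothetical batch of 8 single-channel images

model.eval()
with torch.no_grad():
    batched = model(x)     # all 8 images at once
    single = model(x[:1])  # the first image on its own

# In eval mode BatchNorm uses its running statistics and Dropout is disabled,
# so each image's prediction does not depend on the rest of the batch.
print(torch.allclose(batched[:1], single, atol=1e-5))
```

If the same check fails on the real model in eval mode, some layer is still batch-dependent or random; if it passes, the batch-size dependence must come from how the metric is aggregated.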

I don’t have a public repo for this, but I’m happy to share those classes. I’m using the Unet model directly from this repo without changing it: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py and https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

My dataset class:

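import os
import random

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset

# img_height and img_width (used in transform) are defined elsewhere in the script

class medImDataset(Dataset):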
    def __init__(self, image_path, mask_path, dist_path, contrast = True, hflip=True, train=True):
        self.image_path = image_path
        self.mask_path = mask_path
        self.dist_path = dist_path
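        # os.listdir returns entries in arbitrary order; sort so that image, mask and
        # distance-map files with matching names line up at the same index
        self.image_files = sorted(os.listdir(self.image_path))
        self.mask_files = sorted(os.listdir(self.mask_path))
        self.dist_files = sorted(os.listdir(self.dist_path))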
        self.contrast = contrast
        self.hflip = hflip

    def transform(self, image, mask, dist_map, contrast, hflip):
        resize = transforms.Resize(size=(img_height,img_width))
        image = resize(image)
        mask = resize(mask)
        dist_map = resize(dist_map)

        image = TF.to_grayscale(image)
        mask = TF.to_grayscale(mask)
        dist_map = TF.to_grayscale(dist_map)

        if hflip is True:
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)
                dist_map = TF.hflip(dist_map)

        if contrast is True:
            if random.random() > 0.5:
                image = TF.adjust_contrast(image, contrast_factor=((random.random() * 1.5) + 0.25))

        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        dist_map = TF.to_tensor(dist_map)

        image = TF.normalize(image, mean=[0], std=[1])

        return image, mask, dist_map

    def __getitem__(self, index):
        image_name = self.image_files[index]
        mask_name = self.mask_files[index]
        dist_name = self.dist_files[index]
        image = Image.open(os.path.join(self.image_path, image_name))
        mask = Image.open(os.path.join(self.mask_path, mask_name))
        dist_map = Image.open(os.path.join(self.dist_path, dist_name))
        x, y, z = self.transform(image, mask, dist_map, contrast=self.contrast, hflip=self.hflip)
        return x, y, z

    def __len__(self):
        return len(self.image_files)

My datasets and dataloaders for training and validation:

train_dataset = medImDataset(image_path=train_img_path, mask_path=train_mask_path, dist_path=train_dist_path)

batch_size = 8

train_dataset_loader = DataLoader(
    train_dataset,
    batch_size=batch_size, shuffle=True,
    num_workers=8
)

val_dataset = medImDataset(image_path=val_img_path, mask_path=val_mask_path, dist_path=val_dist_path, contrast=False, hflip=False)

val_batch_size = 8 #I vary this between 1-8, which gives me different validation dice scores

val_dataset_loader = DataLoader(
    val_dataset,
    batch_size=val_batch_size, shuffle=False,
    num_workers=8
)

Hmm, I don’t see anything wrong with these pieces of code.

For the dice score, I would advise having a dice_score function that returns the dice for each element of the batch, and then taking the mean over the concatenation of all the per-image dice scores.

val_dices = []
for val_batch, val_sample in enumerate(val_dataset_loader,0):
    ...
    val_dices.append(dice_score(pred_val_mask, val_mask)) # dice_score returns a tensor of length batch_size
    ...
val_dice_score = torch.cat(val_dices, 0).mean().item()

If your dataset length is not a multiple of your batch_size, averaging per-batch dice scores is going to cause some differences, because the last, smaller batch gets the same weight as the full ones.
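A small sketch of that last-batch effect, isolated from anything else: the five per-image dice values below are made up, and each "batch score" is simply the mean of the per-image scores in that batch, mimicking a loop that sums one value per batch and divides by the number of batches.

```python
import torch

# Hypothetical per-image dice scores for a validation set of 5 images.
per_image = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.1])

def mean_of_batch_means(scores, batch_size):
    # One value per batch, then an unweighted mean over batches.
    batch_means = [chunk.mean() for chunk in scores.split(batch_size)]
    return torch.stack(batch_means).mean().item()

print(per_image.mean().item())            # overall mean, about 0.62
print(mean_of_batch_means(per_image, 2))  # batches of 2, 2 and 1: about 0.533
print(mean_of_batch_means(per_image, 1))  # batch size 1 recovers the overall mean
```

With a batch size of 2, the last image forms a batch on its own and counts as much as the two full batches, which pulls the average down.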

On a different note, it is preferable for a model to return non-normalized activations (logits, before sigmoid or softmax): depending on your loss (not the case in your code), applying the sigmoid or softmax inside the model can cause numerical instability. This mostly happens when the loss uses a log.
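A minimal illustration of that instability, using a single made-up, very confident logit: taking the log of a sigmoid output underflows, while the functions that work on logits directly (F.logsigmoid, F.binary_cross_entropy_with_logits) stay finite.

```python
import torch
import torch.nn.functional as F

logit = torch.tensor([-200.0])  # a very confident (and wrong) prediction
target = torch.tensor([1.0])

# Sigmoid first, then log: sigmoid(-200) underflows to 0 in float32, so log gives -inf.
naive = torch.log(torch.sigmoid(logit))

# Working on the logits keeps the value finite.
stable = F.logsigmoid(logit)                                  # about -200
bce_logits = F.binary_cross_entropy_with_logits(logit, target)  # about 200

print(naive.item(), stable.item(), bce_logits.item())
```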

Oh okay, I think the problem might be that my dice_score function returns a single dice score for the batch of images instead of a score for each image in the batch. I (incorrectly?) assumed that the dice of the batch is the same as the mean of the dice scores of each image.

In that case, what is the correct way to implement my DiceLoss class? I do the same thing there and calculate a single dice loss for the entire batch instead of one for each image in the batch.

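Indeed, the dice of the batch is not the same as the mean of the dice of each image. This comes from the fact that the positive regions are (I assume) not the same size in every image: the batch dice implicitly weights each image by the size of its predicted and true positive regions, so large objects dominate it.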
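A worked example of that weighting, with two hypothetical 10x10 masks: one image has a large object that is predicted perfectly, the other a small object that is missed entirely. The last lines check that the batch dice equals a mean of the per-image dice weighted by sum(pred) + sum(target) for each image.

```python
import torch

def batch_dice(pred, target, eps=1e-8):
    # One dice for the whole batch (the original approach).
    inter = (pred * target).sum()
    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)

def per_image_dice(pred, target, eps=1e-8):
    # One dice per image (first dimension is the batch).
    inter = (pred * target).flatten(1).sum(1)
    sizes = pred.flatten(1).sum(1) + target.flatten(1).sum(1)
    return (2 * inter + eps) / (sizes + eps)

# Image 0: large object (100 pixels), predicted perfectly.
# Image 1: small object (4 pixels), missed entirely.
target = torch.zeros(2, 10, 10)
target[0] = 1
target[1, :2, :2] = 1
pred = torch.zeros(2, 10, 10)
pred[0] = 1

print(batch_dice(pred, target).item())             # 200 / 204, about 0.980
print(per_image_dice(pred, target).mean().item())  # (1 + 0) / 2 = 0.5

# Batch dice as a weighted mean of per-image dice (equal up to eps).
d = per_image_dice(pred, target)
w = pred.flatten(1).sum(1) + target.flatten(1).sum(1)
print(((w * d).sum() / w.sum()).item())
```

Splitting the same two images into batches of size 1 would report 0.5, while a single batch of 2 reports about 0.98, which is exactly the kind of batch-size dependence described in the question.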

For both your loss and evaluation, you could actually use the same function:

def dice_score(input: torch.Tensor, target: torch.Tensor, eps: float=1e-8) -> torch.Tensor:
    """input and target should be FloatTensor where the first dimension is the batch."""
    intersection = (input * target).flatten(1).sum(1)
    sum_both = input.flatten(1).sum(1) + target.flatten(1).sum(1)
    return (2 * intersection + eps) / (sum_both + eps)

This works only if you consider one class.
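A sketch of how the per-image function could serve both purposes: averaged into a scalar for the training loss, and accumulated with torch.cat for validation. The predictions and masks are random stand-ins (13 images, so no batch size below divides the set evenly), and the 0.5 threshold mirrors the fix discussed earlier.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def dice_score(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Per-image dice (first dimension is the batch), single class."""
    intersection = (input * target).flatten(1).sum(1)
    sum_both = input.flatten(1).sum(1) + target.flatten(1).sum(1)
    return (2 * intersection + eps) / (sum_both + eps)

def dice_loss(pred_probs, target):
    # Training: average the per-image dice so the loss is a scalar.
    return 1 - dice_score(pred_probs, target).mean()

def evaluate(preds, masks, batch_size):
    # Validation: collect per-image scores and take a single mean at the end.
    loader = DataLoader(TensorDataset(preds, masks), batch_size=batch_size, shuffle=False)
    val_dices = []
    for pred_batch, mask_batch in loader:
        binary = (pred_batch > 0.5).float()
        val_dices.append(dice_score(binary, mask_batch))
    return torch.cat(val_dices, 0).mean().item()

# Hypothetical stand-ins for model outputs and ground-truth masks.
torch.manual_seed(0)
masks = (torch.rand(13, 1, 8, 8) > 0.7).float()
preds = torch.rand(13, 1, 8, 8)

for bs in (1, 4, 8):
    print(bs, evaluate(preds, masks, bs))  # same score for every batch size (up to rounding)
print("training loss:", dice_loss(preds, masks).item())
```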

Okay, thanks a lot for your help - I’m going to rewrite my loss and score functions.

Related question, in case you know: is the generalized dice loss the same as the dice loss when there is only one class? I recall the generalized dice loss weights each class’s score based on the number of positive samples in that class.

Well, you can easily test it yourself, but the answer is yes, although background vs. foreground can also be considered as two classes.
For a one-class scenario (w is the class weight, e.g. 1 / sum( r )^2, so it cancels out):
DL(p, r) = 1 - 2 * sum( p * r ) / (sum( p ) + sum( r ))
GDL(p, r) = 1 - 2 * w * sum( p * r ) / (w * (sum( p ) + sum( r ))) = 1 - 2 * sum( p * r ) / (sum( p ) + sum( r ))
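A quick numerical check of that equivalence, assuming the common class weight w_c = 1 / sum(r_c)^2; the prediction and reference below are random stand-ins. The last line treats background and foreground as two classes, which generally gives a different value.

```python
import torch

def dice_loss(p, r):
    # DL(p, r) = 1 - 2 * sum(p * r) / (sum(p) + sum(r))
    return 1 - 2 * (p * r).sum() / (p.sum() + r.sum())

def generalized_dice_loss(p, r):
    """p, r: shape (C, N), classes x pixels; weights w_c = 1 / sum(r_c)^2."""
    w = 1.0 / r.sum(1) ** 2
    numerator = (w * (p * r).sum(1)).sum()
    denominator = (w * (p.sum(1) + r.sum(1))).sum()
    return 1 - 2 * numerator / denominator

torch.manual_seed(0)
p = torch.rand(1, 100)                  # hypothetical soft prediction, one class
r = (torch.rand(1, 100) > 0.5).float()  # hypothetical binary reference

print(dice_loss(p, r).item(), generalized_dice_loss(p, r).item())  # equal up to rounding

# Background + foreground as two classes: generally not equal to the one-class DL.
print(generalized_dice_loss(torch.cat([p, 1 - p]), torch.cat([r, 1 - r])).item())
```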


Thanks for your help, I fixed the issue by measuring dice score for each image in the validation batch and then averaging, instead of measuring dice once on the entire batch.

On the other hand, I found that training converged much faster and was much more stable when using my original ‘batch dice loss’ instead (i.e. the same as what is posted here: https://github.com/pytorch/pytorch/issues/1249#issuecomment-305088398).

I am guessing you are trying to segment small objects? In that case, you should consider computing your dice loss on both the background and foreground as:

loss = 1 - ((dice_score(mask_pred, true_masks) + dice_score(1 - mask_pred, 1 - true_masks)) / 2).mean()  # mean() reduces per-image scores to a scalar
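A self-contained sketch of that foreground + background loss, assuming (as suggested earlier in the thread) that the model returns raw logits for a single foreground channel, so the sigmoid is applied inside the loss. The function name fg_bg_dice_loss and the tiny-object batch are hypothetical.

```python
import torch

def dice_score(input, target, eps=1e-8):
    # Per-image dice (first dimension is the batch).
    intersection = (input * target).flatten(1).sum(1)
    sum_both = input.flatten(1).sum(1) + target.flatten(1).sum(1)
    return (2 * intersection + eps) / (sum_both + eps)

def fg_bg_dice_loss(logits, true_masks):
    # Assumes logits for one foreground channel; sigmoid gives foreground probabilities.
    mask_pred = torch.sigmoid(logits)
    fg = dice_score(mask_pred, true_masks)
    bg = dice_score(1 - mask_pred, 1 - true_masks)
    return 1 - ((fg + bg) / 2).mean()

# Hypothetical batch: 2 images, each with a tiny 2x2 object.
true_masks = torch.zeros(2, 1, 16, 16)
true_masks[:, :, 7:9, 7:9] = 1
logits = torch.full((2, 1, 16, 16), -5.0, requires_grad=True)  # predicts "all background"

loss = fg_bg_dice_loss(logits, true_masks)
loss.backward()
print(loss.item(), logits.grad.abs().sum().item())
```

Averaging the two terms means an "all background" prediction cannot score well just by getting the large background right: the foreground term stays near zero and keeps the loss high.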

You could also consider the GDL or even the Boundary Loss if you are indeed segmenting small objects.

Thanks for the tip, I will try that.

I am indeed implementing a boundary loss like the one in the MIDL 2019 paper, but I haven’t yet seen it give an overall improvement over using Dice alone.

I guess it depends on your application. Usually, boundary loss is not used alone but as a complement to GDL.
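A rough sketch of using the boundary term as a complement to GDL. Assumptions: a two-class (background, foreground) softmax output, a precomputed signed distance map that is negative inside the object and positive outside (presumably what the dataset's dist_map holds), and a hypothetical mixing weight alpha that is not taken from this thread. The brute-force signed_distance helper is for illustration on tiny masks only.

```python
import torch

def generalized_dice_loss(probs, one_hot, eps=1e-8):
    """probs, one_hot: (B, C, H, W). Class weights w_c = 1 / (reference pixels in c)^2."""
    dims = (0, 2, 3)  # sum over batch and spatial dims, one value per class
    ref = one_hot.sum(dims)
    w = 1.0 / (ref ** 2 + eps)
    intersection = (w * (probs * one_hot).sum(dims)).sum()
    union = (w * (probs.sum(dims) + ref)).sum()
    return 1 - 2 * intersection / (union + eps)

def boundary_loss(fg_probs, signed_dist):
    # Foreground probabilities weighted by the signed distance to the true boundary.
    return (fg_probs * signed_dist).mean()

def combined_loss(logits, one_hot, signed_dist, alpha):
    # alpha: hypothetical weight between the two terms.
    probs = torch.softmax(logits, dim=1)
    return (1 - alpha) * generalized_dice_loss(probs, one_hot) + alpha * boundary_loss(probs[:, 1:2], signed_dist)

def signed_distance(mask):
    """Brute-force signed Euclidean distance for a small 2D binary mask:
    negative inside the object, positive outside (illustration only)."""
    h, w = mask.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    coords = torch.stack([ys.flatten(), xs.flatten()], dim=1).float()
    inside = mask.flatten() > 0
    d = torch.cdist(coords, coords)               # (H*W, H*W) pairwise distances
    to_inside = d[:, inside].min(dim=1).values    # distance to nearest object pixel
    to_outside = d[:, ~inside].min(dim=1).values  # distance to nearest background pixel
    return torch.where(inside, -to_outside, to_inside).reshape(h, w)

# Hypothetical example: one 16x16 image with a 6x6 object.
mask = torch.zeros(16, 16)
mask[5:11, 5:11] = 1
dist = signed_distance(mask)[None, None]                # (1, 1, H, W)
one_hot = torch.stack([1 - mask, mask])[None]           # (1, 2, H, W): background, foreground
logits = torch.randn(1, 2, 16, 16, requires_grad=True)  # stand-in for the network output

loss = combined_loss(logits, one_hot, dist, alpha=0.01)
loss.backward()
print(loss.item())
```

Because the boundary term is negative for foreground mass inside the object and positive outside it, on its own it gives no incentive to cover the whole region; that is one reason it is usually paired with a regional loss such as GDL.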