Data augmentation results in network not learning at all

Hi all!

I have a modified UNet with a BiFPN between the encoder and decoder for segmentation of MS lesions in T1 and FLAIR MRI images. After implementing augmentations such as rotation, scaling, and horizontal flip, my network won't learn anything. The strange part is that when I remove the data augmentation, the network learns and produces segmentation masks. Any suggestions?
Thanks!

Did you make sure to apply the same transformations to the input images as well as the segmentation masks?
To do so, I would recommend using the functional transformation API, as it allows you to reuse the same "random" parameters for each transformation, as described here.
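
For example, a minimal sketch of such a paired transform (the angle range, flip probability, and function name here are placeholders, not your actual settings):

import random
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode

def paired_transform(image, mask, max_angle=10.0, flip_p=0.5):
    # Sample the random parameters once, then apply them to image and mask
    angle = random.uniform(-max_angle, max_angle)
    image = TF.rotate(image, angle, interpolation=InterpolationMode.BILINEAR)
    # Nearest-neighbor interpolation keeps the mask labels discrete
    mask = TF.rotate(mask, angle, interpolation=InterpolationMode.NEAREST)
    if random.random() < flip_p:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    return image, mask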

Yes, I am currently looping through a dict where the FLAIR, T1, and mask are transformed. I plotted the transformed image against the original and it looked fine. This confuses me a lot. Here's the code for the transformations:

import numpy as np
import torch.nn.functional as nn_F
import torchvision.transforms.functional as F
import torchvision.transforms as transform
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import math
from collections.abc import Sequence

class Random_Rotation(object):
    def __init__(self, angle, interpolation_method=transform.InterpolationMode.BILINEAR):
        if not isinstance(angle, (float, int)):
            raise TypeError(f"Please provide a valid angle type (int/float), provided {type(angle)}")
        self.angle = angle
        self.interpolation_method = interpolation_method
    
    def __call__(self, image_sample):
        # Draw one angle so the modalities and the mask are rotated identically
        current_angle = np.random.uniform(-self.angle, self.angle)

        if isinstance(image_sample, list):
            unpacked = enumerate(image_sample)
        elif isinstance(image_sample, dict):
            unpacked = image_sample.items()
        else:
            raise TypeError(f"Please provide a list or dict of images, given type {type(image_sample)}")

        for idx, samples in unpacked:
            if not isinstance(samples, torch.Tensor):
                samples = torch.from_numpy(samples)

            image_sample[idx] = F.rotate(samples, current_angle, interpolation=self.interpolation_method)

        return image_sample


class Random_Horizontal_Flip(object):
    def __init__(self, flip_probability=0.5):
        if not isinstance(flip_probability, (int, float)):
            raise TypeError(f"Please provide a valid value for the flip probability, you typed: {flip_probability}")
        if flip_probability > 1.0 or flip_probability < 0:
            raise ValueError(f"The probability has to be in the interval [0,1], you typed {flip_probability}")

        self.pvalue = flip_probability

    def __call__(self, image_sample):
        if isinstance(image_sample, list):
            unpacked = enumerate(image_sample)
        elif isinstance(image_sample, dict):
            unpacked = image_sample.items()
        else:
            raise TypeError(f"Please provide a list or dict of images, given type {type(image_sample)}")

        # Flip either every image in the sample or none of them
        if torch.rand(1) < self.pvalue:
            for idx, samples in unpacked:
                if not isinstance(samples, torch.Tensor):
                    samples = torch.from_numpy(samples)

                image_sample[idx] = F.hflip(samples)

        
class Random_Scale(object):
    def __init__(self, scale_factor, interpolation_method=transform.InterpolationMode.BILINEAR):
        if not isinstance(scale_factor, (int, float)):
            raise TypeError(f"Please provide a valid value for the scaling factor, you typed: {scale_factor}")

        self.scale_factor = scale_factor
        self.interpolation_method = interpolation_method

    def __call__(self, image_sample):
        # Draw one scale so the modalities and the mask are resized identically
        current_scale = np.random.uniform(1 - self.scale_factor, 1 + self.scale_factor)

        if isinstance(image_sample, list):
            unpacked = enumerate(image_sample)
            height, width = image_sample[0].shape[1:]
        elif isinstance(image_sample, dict):
            unpacked = image_sample.items()
            key = list(image_sample.keys())[0]
            height, width = image_sample[key].shape[1:]
        else:
            raise TypeError(f"Please provide a list or dict of images, given type {type(image_sample)}")

        if height != width:
            raise ValueError("It seems that the height and the width are not equal")

        original_img_size = (height, width)
        height = int(round(height * current_scale))
        width = int(round(width * current_scale))
        new_size = (height, width)

        for idx, samples in unpacked:
            if not isinstance(samples, torch.Tensor):
                samples = torch.from_numpy(samples)

            image_sample[idx] = F.resize(samples, new_size, self.interpolation_method)

            if current_scale < 1.0:
                # Zoomed out: pad the smaller image back to the original size
                diff = original_img_size[0] - new_size[0]
                left = diff // 2
                right = left + diff % 2

                padding_ = (int(left), int(right)) * 2

                image_sample[idx] = nn_F.pad(image_sample[idx], padding_, "constant", 0)
            else:
                # Zoomed in: center-crop back to the original size
                x_min = (new_size[0] - original_img_size[0]) // 2
                x_max = x_min + original_img_size[0]

                image_sample[idx] = image_sample[idx][..., x_min:x_max, x_min:x_max]

        return image_sample

Thanks for the update. The transformations look alright, assuming that image_sample contains the image and mask tensors. Random_Horizontal_Flip is missing the return statement, but I assume that's just a copy-paste issue.
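
As an additional sanity check, you could verify that the mask is still binary after the augmentation, since bilinear interpolation creates intermediate values along the lesion borders (a minimal sketch, assuming a dict sample whose "mask" entry holds 0/1 labels):

import torch

def check_mask_binary(image_sample):
    # A 0/1 mask should still contain only 0s and 1s after augmentation
    mask = image_sample["mask"]
    print("unique mask values:", torch.unique(mask))
    assert bool(((mask == 0) | (mask == 1)).all()), "mask is no longer binary"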

Haha yes, it's a copy-paste issue. And yes, image_sample is a dict with the FLAIR, T1, and mask tensors. Do you have any suggestion why the network is not learning? The dice score stays constantly in the interval [0.01, 0.02]. As I stated earlier, when I remove the augmentation, which means commenting out just one line of code, it suddenly learns and everything is fine 😥

I would recommend trying to overfit a small dataset, e.g. just 10 samples, and checking whether the model is able to do so. If not, I would guess there might be another issue in the code that we haven't found yet.
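
For example (a sketch; train_dataset and the loader settings are placeholders for your own objects):

from torch.utils.data import DataLoader, Subset

# Train on just the first 10 samples; a working model should be able
# to memorize them and drive the dice score close to 1.0
small_train = Subset(train_dataset, list(range(10)))
small_loader = DataLoader(small_train, batch_size=2, shuffle=True)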


Ok, I am going to try that out and update you later ☺️👌

It does learn, but very slowly. This is the result after 133 epochs:


Phase: train, Epoch:133/200, learning rate: 0.0007537289284477406

recall: 1.0, difference [current - previous]: 0.0
precision: 0.770689845085144, difference [current - previous]: 0.0017197132110595703 ↑
Fscore: 0.780264675617218, difference [current - previous]: 0.0016559362411499023 ↑
dice_score: 0.7828959439802119, difference [current - previous]: 0.0016377398266059773 ↑
diceloss: 0.21710405601978813, difference [current - previous]: -0.001637739826606005 ↓

But I don't know why it doesn't work for the complete training set; after 100 epochs the dice loss is almost 1. Any suggestions?