Semantic segmentation model (UNet) doesn't learn

I tried training a UNet model written in PyTorch, but I can't seem to make it work. I tried training on a single image (the dataset is Carvana) for 500 epochs, but the output is pure black. Any help would be appreciated.

Here is the link to my Kaggle kernel: Carvana-Pytorch

Did you make sure the target tensors represent valid segmentation masks?
Your current code is:

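```python
from PIL import Image
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, image_paths, mask_paths, train=True):
        self.image_paths = image_paths
        self.mask_paths = mask_paths

    def transforms(self, image, mask):
        # Resize both to 256x256 with the default interpolation (bilinear), mask included
        image = transforms.Resize(size=(256, 256))(image)
        mask = transforms.Resize(size=(256, 256))(mask)
        # Convert PIL images to float tensors (8-bit inputs are divided by 255)
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask

    def __getitem__(self, index):
        # Load the image and its mask from disk
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.mask_paths[index])
        x, y = self.transforms(image, mask)
        return x, y

    def __len__(self):
        return len(self.image_paths)
```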

I'm a bit concerned about the transformations applied to the target.
If you are loading a valid segmentation mask containing only the class indices (in your case, zeros for the background and ones for the car), the Resize and ToTensor transformations might corrupt the target values.

For Resize you should use PIL.Image.NEAREST as the interpolation, since PIL.Image.BILINEAR is the default one and will interpolate to invalid class indices at the borders.
Depending on which image type you pass to ToTensor, it might also scale the values (8-bit images are divided by 255), which can cause further trouble.

Could you print some target tensors as a sanity check and make sure they only contain the expected class indices?
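One way to run that check, sketched with a hypothetical helper `unexpected_values` that lists every value in a target that is not a valid class index (0 and 1 are assumed for Carvana). The commented loop assumes the dataset from the question has been constructed as `dataset`.

```python
import torch

def unexpected_values(target, valid_classes=(0, 1)):
    # Hypothetical helper: list every value in the target that is not a valid class index
    return [v for v in torch.unique(target).tolist() if v not in valid_classes]

# Usage with the dataset from the question (assumed to be constructed as `dataset`):
# for i in range(5):
#     _, y = dataset[i]
#     print(i, y.shape, y.dtype, torch.unique(y), unexpected_values(y))

good = torch.tensor([[0, 1], [1, 0]])
scaled = good.float() / 255  # what ToTensor produces for a uint8 0/1 mask

print(unexpected_values(good))    # empty list: only valid indices
print(unexpected_values(scaled))  # the 1/255 value is flagged
```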


I resized the images using PIL.Image.NEAREST and the result seems the same. I think the ToTensor method normalizes the image, because the values of the image are between -1 and 1. I have one output channel; is this a problem?
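A single output channel is generally fine for binary car/background segmentation, as long as the loss and target match it. The kernel's actual loss is not shown here, so the setup below is an assumption: one logit channel trained with `nn.BCEWithLogitsLoss` against a 0/1 float target of the same shape. A target scaled to 0 and 1/255 would push every prediction towards zero, which would be consistent with an all-black output. The two-channel `nn.CrossEntropyLoss` variant is shown as an alternative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed shapes: batch of 2 images, 256x256, one output channel (car vs. background)
logits = torch.randn(2, 1, 256, 256)                     # raw UNet output, no sigmoid applied
target = torch.randint(0, 2, (2, 1, 256, 256)).float()   # 0/1 mask with the same shape

# BCEWithLogitsLoss applies the sigmoid internally, so the model should output raw logits
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, target)

# Turn logits into a binary mask for visualization
pred_mask = (torch.sigmoid(logits) > 0.5).long()

# Alternative: two output channels with CrossEntropyLoss and a [N, H, W] long target
two_channel_logits = torch.randn(2, 2, 256, 256)
ce_loss = nn.CrossEntropyLoss()(two_channel_logits, target.squeeze(1).long())
```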

I also checked that the masks match; here is an updated version of the notebook: Updated

I multiplied the values of the image by 255 and it works now, thanks @ptrblck!