U-NET - Black Tiles

Hi, everybody!

New PyTorch user here. I am trying to implement binary semantic segmentation using a U-Net.
To get myself started I took some code from this example.

My dataset contains some 250 images and I used rotations to expand that to nearly 1000.

Here is an example of my images and targets:
[image: target mask]

Now when I train the model two odd things happen:

  1. The model achieves high accuracy very quickly, after just a couple of epochs.
  2. All prediction masks are blank. The more epochs the blanker they are.

After a certain amount of reading, I believe the problem is that the targets in my images are fairly small.
The model therefore gets a high accuracy by predicting zero for every pixel, essentially training itself to paint everything black.
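
A toy sketch of why accuracy is misleading in this situation. The 100x100 mask with a 5x5 target is a made-up example, not the actual data, and the function names are only illustrative. It compares pixel accuracy with foreground IoU for a completely blank prediction:

```python
# Hypothetical 100x100 mask in which only a 5x5 patch is foreground.
HEIGHT, WIDTH = 100, 100

# Ground truth: 1 = target pixel, 0 = background.
target = [[0] * WIDTH for _ in range(HEIGHT)]
for row in range(10, 15):
    for col in range(10, 15):
        target[row][col] = 1

# A "blank" prediction: background everywhere.
prediction = [[0] * WIDTH for _ in range(HEIGHT)]


def flatten_pairs(pred, truth):
    """Return a flat list of (predicted, true) pixel pairs."""
    return [
        (p, t)
        for pred_row, truth_row in zip(pred, truth)
        for p, t in zip(pred_row, truth_row)
    ]


def pixel_accuracy(pred, truth):
    """Fraction of pixels where prediction and ground truth agree."""
    pairs = flatten_pairs(pred, truth)
    return sum(1 for p, t in pairs if p == t) / len(pairs)


def foreground_iou(pred, truth):
    """Intersection over union, computed for the foreground class only."""
    pairs = flatten_pairs(pred, truth)
    intersection = sum(1 for p, t in pairs if p == 1 and t == 1)
    union = sum(1 for p, t in pairs if p == 1 or t == 1)
    # If neither mask has any foreground, treat the match as perfect.
    return intersection / union if union else 1.0


accuracy = pixel_accuracy(prediction, target)
iou = foreground_iou(prediction, target)
print(f"pixel accuracy: {accuracy:.4f}")  # 0.9975
print(f"foreground IoU: {iou:.4f}")       # 0.0000
```

The blank mask scores 99.75% accuracy while finding none of the target, so a foreground-only metric such as IoU is probably a better thing to watch during training.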

The best solutions appear to be:

  1. Way more training epochs
  2. A weighted scoring approach where positive pixels are favoured.

Any thoughts on this analysis and hints on how to achieve the weighted scoring would be appreciated. I’m still very new to all this.

Cheers!

Update:

I trained the model for 50 epochs and the resulting predictions were again blanks.

I then assembled a data set in which the targets make up a large part of every image
(small targets potentially being the issue).

After training on the new set for 10 epochs I got this result. Not perfect but not terrible.
[image: prediction after 10 epochs, first run]

However, retraining on the same set from scratch with the same settings delivers this, where everything is flipped:
[image: prediction after 10 epochs, second run]

Also, training for more than 10 epochs again produces blank predictions.

Has anyone encountered this before?

The blank predictions can indeed be caused by a highly imbalanced target distribution in the mask images, and you could try a weighted loss function to counter this effect.
I’m unsure about the last effect, where the predictions are flipped and then go back to blank again.
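
One way to set up such a weighted loss is the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss`. This is a minimal sketch under assumptions: the model is taken to output raw logits of shape `(batch, 1, H, W)`, the random tensors stand in for real masks and U-Net outputs, and the background-to-foreground pixel ratio is just one common heuristic for the weight, not a tuned value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a batch of binary masks with roughly 1% foreground pixels.
masks = (torch.rand(4, 1, 64, 64) < 0.01).float()

# Stand-in for the raw U-Net output (logits, no sigmoid applied).
logits = torch.randn(4, 1, 64, 64, requires_grad=True)

# Weight positive pixels by the background-to-foreground ratio.
# clamp() avoids a division by zero if a batch has no foreground at all.
num_pos = masks.sum()
num_neg = masks.numel() - num_pos
pos_weight = (num_neg / num_pos.clamp(min=1.0)).reshape(1)

# Weighted loss: errors on foreground pixels cost pos_weight times more.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, masks)

# Unweighted loss on the same data, for comparison only.
plain_loss = nn.BCEWithLogitsLoss()(logits, masks)

loss.backward()
print(f"pos_weight: {pos_weight.item():.1f}")
print(f"weighted loss: {loss.item():.4f}, unweighted loss: {plain_loss.item():.4f}")
```

In a real training loop the weight would more likely be computed once over all training masks rather than per batch, and the model's final layer must not apply a sigmoid, since `BCEWithLogitsLoss` applies it internally.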

Many thanks for the reply.

I have experimented with smaller image tiles (256px instead of 512px), and the results appear to be similar.

I will look into putting together a weighted loss function.

Could you please say whether you got correct results with a weighted loss, and if so, what it was?