Loss function for semantic segmentation in high resolution images, using patching/tiling strategy

Hi, I am doing binary semantic segmentation on large images (6000x6000). My data is highly skewed, meaning there are only very few objects in each image. Because of the high resolution I need a patching/tiling strategy (at most 512x512 due to memory limits). The issue arises during training: since I crop randomly, a lot of the corresponding masks are empty/all black. I therefore need a loss function that can handle a large number of masks containing only zeros while still, most importantly, learning from the masks that do contain actual content. See the images below for reference.
All suggestions are very welcome :slight_smile:

Please note these are not the actual images, but they mimic my issue.

Hi Jona!

Consider using a U-Net as your model (if you’re not already). U-Net is
designed for performing semantic segmentation on large images, and
generally works well. (The sample image you posted appears to be
appropriate for U-Net.)

U-Net – as a fully-convolutional network – accepts images of arbitrary size,
and with its “tiling strategy” it can process arbitrarily large images regardless
of (reasonable) memory limitations.

It has a finite “field of view” (which can be either a good thing or a bad thing),
so, for example, the data in an input image’s upper left corner has no effect
on the predictions made for the lower right corner (as long as the two
corners are outside of one another’s fields of view).

Pytorch’s BCEWithLogitsLoss is likely the most appropriate loss function
for your use case.
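
For binary segmentation this amounts to having your model output a single logit channel per pixel and comparing it against a floating-point 0/1 mask of the same shape. A minimal sketch, where the tensors are just stand-ins for a batch of 512x512 patches:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# Stand-ins for the model output and the ground-truth masks: one raw-score
# (logit) channel per pixel, and a 0.0 / 1.0 mask of the same shape.
logits = torch.randn(4, 1, 512, 512, requires_grad=True)
target = torch.randint(0, 2, (4, 1, 512, 512)).float()

loss = criterion(logits, target)
loss.backward()
```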

You describe a situation where you have imbalanced training data in that
you have many more background pixels than foreground pixels. For your
use case I would suggest a combination of two standard approaches to
dealing with imbalanced training data: weighted sampling of the training
data and class weights in the loss function.

Specifically, I wouldn’t crop purely randomly. Instead, crop out patches and
separate them into those that include a foreground object and those that do
not (but otherwise crop at random positions). Make the patches at least as
large as your U-Net’s field of view. From your description, you will have many
more no-foreground-object patches than foreground-object patches in your
cropped-patch training set.
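
A sketch of how such a pre-cropping step might look, assuming image tensors of shape (C, H, W) and mask tensors of shape (H, W); the helper name and arguments are illustrative, not anything from your actual setup:

```python
import torch

def random_patches(image, mask, patch_size=512, n_patches=200):
    """Crop patches at random positions and record whether each one
    contains any foreground pixels."""
    patches, has_object = [], []
    _, H, W = image.shape
    for _ in range(n_patches):
        top = torch.randint(0, H - patch_size + 1, (1,)).item()
        left = torch.randint(0, W - patch_size + 1, (1,)).item()
        img_patch = image[:, top:top + patch_size, left:left + patch_size]
        msk_patch = mask[top:top + patch_size, left:left + patch_size]
        patches.append((img_patch, msk_patch))
        has_object.append(bool(msk_patch.any()))
    return patches, has_object
```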

Now when you randomly select training patches in your training loop, bias
that selection so that you select, on average, equal numbers of object and
no-object patches. (WeightedRandomSampler might prove to be helpful
when implementing such sampling.)
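
For instance, if you have recorded a has-object flag for each cropped patch (as in the sketch above), you could give each patch a sampling weight inversely proportional to the size of its group. `patches` and `has_object` are assumed to come from a pre-cropping step like the one sketched earlier, and `PatchDataset` is a placeholder for whatever Dataset wraps those (image, mask) patch pairs:

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

flags = torch.tensor(has_object, dtype=torch.bool)
n_pos = int(flags.sum())
n_neg = int((~flags).sum())

# Weight each patch inversely to the size of its group, so object and
# no-object patches are drawn with equal probability on average.
weights = torch.empty(len(flags), dtype=torch.float)
weights[flags] = 1.0 / n_pos
weights[~flags] = 1.0 / n_neg

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(PatchDataset(patches), batch_size=8, sampler=sampler)
```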

Depending on the character of your data, any given object patch may still
have significantly more background pixels than foreground pixels, so your
training data could still be imbalanced in this regard. If so, use
BCEWithLogitsLoss’s pos_weight constructor argument to weight the (fewer)
foreground pixels more heavily in the loss function.
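
As a sketch, assuming you have counted (or estimated) the background-to-foreground pixel ratio over your training masks (the counts below are made up):

```python
import torch
import torch.nn as nn

# Placeholder pixel counts over the training masks; in practice you would
# count these from your own data.
n_foreground = 500_000
n_background = 35_500_000

# pos_weight > 1 up-weights the positive (foreground) class in the loss.
pos_weight = torch.tensor([n_background / n_foreground])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```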

One note: Precise use of U-Net has some nuances. When used with proper
care, running a large image through U-Net as a bunch of separate patches
using the tiling strategy yields exactly the same predictions as if you had run
the whole image through U-Net all at once. (This is a good thing.)

Best.

K. Frank