Hi Jona!
Consider using a U-Net as your model (if you’re not already). U-Net is
designed for performing semantic segmentation on large images, and
generally works well. (The sample image you posted appears to be
appropriate for U-Net.)
U-Net, as a fully-convolutional network, accepts input images of varying size,
and, with its “tiling strategy,” it can process images of arbitrary size, even
when (reasonable) memory limitations prevent running the whole image at once.
It has a finite “field of view” (which can be either a good thing or a bad thing),
so, for example, the data in an input image’s upper left corner has no effect
on the predictions made for the lower right corner (as long as the two
corners are outside of one another’s fields of view).
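If you want a rough number for that field of view, you can compute it layer by layer. Here is a minimal sketch for a plain chain of convolutions and poolings. The helper receptive_field and the layer list are hypothetical, assuming the layout of the original U-Net’s contracting path (two unpadded 3x3 convolutions per level, then a 2x2, stride-2 max pool, over four levels, plus two 3x3 convolutions at the bottleneck). It only covers the contracting path; the expanding path’s additional convolutions make the full network’s field of view somewhat larger, so treat the result as a lower bound.

```python
def receptive_field(layers):
    """Receptive field, in input pixels, of a plain sequential stack (hypothetical helper).

    Each layer is a (kernel_size, stride) pair. Padding changes the output size
    but not the receptive field; dilation is not modeled here.
    """
    field = 1  # field of view of a single output pixel so far
    jump = 1   # spacing, in input pixels, between adjacent pixels of the current feature map
    for kernel, stride in layers:
        field += (kernel - 1) * jump
        jump *= stride
    return field


# Hypothetical U-Net-style contracting path: at each of four levels, two 3x3
# convolutions followed by a 2x2, stride-2 max pool; then two 3x3 convolutions.
level = [(3, 1), (3, 1), (2, 2)]
contracting_path = level * 4 + [(3, 1), (3, 1)]

print(receptive_field(contracting_path))  # 140
```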
PyTorch’s BCEWithLogitsLoss is likely the most appropriate loss function
for your use case, since you are classifying each pixel as either foreground
or background.
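A minimal sketch of the shapes involved, assuming your model ends in a single output channel of raw logits (no final sigmoid). The tensor sizes are made up for illustration, and the hand-computed version is only there to show what the loss computes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up batch: 2 single-channel 8x8 predictions, as raw logits (no sigmoid).
# BCEWithLogitsLoss applies the sigmoid internally in a numerically stable way.
logits = torch.randn(2, 1, 8, 8)
# Targets have the same shape and must be floating point: 1.0 = foreground, 0.0 = background.
target = (torch.rand(2, 1, 8, 8) > 0.9).float()

loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, target)

# The same quantity computed by hand (mean over all pixels), for comparison.
p = torch.sigmoid(logits)
manual = -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()
print(loss.item(), manual.item())
```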
You describe a situation where you have imbalanced training data in that
you have many more background pixels than foreground pixels. For your
use case I would suggest a combination of two standard approaches to
dealing with imbalanced training data: weighted sampling of the training
data and class weights in the loss function.
Specifically, I wouldn’t crop purely at random. Crop out patches that either
contain a foreground object or contain no foreground at all (but otherwise
choose their locations randomly). Crop out patches that are at least as large
as your U-Net’s field of view. From your description, you will have many more
no-foreground-object patches than foreground-object patches in your
cropped-patch training set.
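As a rough sketch of the labeling step: the function crop_labeled_patches is hypothetical, and the “any foreground pixel” criterion is an assumption (you might instead require that a whole object lie inside the patch). The image, mask, and patch size below are made up.

```python
import torch


def crop_labeled_patches(image, mask, patch_size, num_patches, generator=None):
    """Randomly crop patches and label each as object / no-object (hypothetical helper).

    image: (C, H, W) tensor; mask: (H, W) tensor, nonzero = foreground.
    Assumption: a patch is an "object patch" if it contains any foreground pixel.
    """
    _, h, w = image.shape
    patches, patch_masks, has_object = [], [], []
    for _ in range(num_patches):
        # Random top-left corner such that the patch fits inside the image.
        top = int(torch.randint(0, h - patch_size + 1, (1,), generator=generator))
        left = int(torch.randint(0, w - patch_size + 1, (1,), generator=generator))
        patches.append(image[:, top:top + patch_size, left:left + patch_size])
        m = mask[top:top + patch_size, left:left + patch_size]
        patch_masks.append(m)
        has_object.append(bool(m.any()))
    return patches, patch_masks, has_object


# Made-up example: one small foreground object in a 512x512 image, with
# 192x192 patches (a hypothetical size assumed to exceed the field of view).
g = torch.Generator().manual_seed(0)
image = torch.rand(3, 512, 512)
mask = torch.zeros(512, 512)
mask[200:220, 300:320] = 1
patches, patch_masks, has_object = crop_labeled_patches(image, mask, 192, 200, generator=g)
print(sum(has_object), "object patches out of", len(has_object))
```

The has_object flags are exactly what you need to bias the sampling of patches during training.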
Now when you randomly select training patches in your training loop, bias
that selection so that you select, on average, equal numbers of object and
no-object patches. (WeightedRandomSampler might prove to be helpful when
implementing such sampling.)
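A minimal sketch of such sampling with WeightedRandomSampler, using made-up labels. Each patch gets a weight inversely proportional to the size of its class, so both classes carry equal total weight and the sampler draws object and no-object patches about equally often.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Made-up labels for a cropped-patch training set: 1 = object patch, 0 = no-object patch.
# No-object patches outnumber object patches nine to one.
has_object = torch.tensor([1] * 100 + [0] * 900)

# Weight each patch by 1 / (size of its class), so both classes carry equal total weight.
num_object = int(has_object.sum())
num_empty = len(has_object) - num_object
class_weight = torch.tensor([1.0 / num_empty, 1.0 / num_object])  # indexed by label 0 / 1
weights = class_weight[has_object]

g = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True, generator=g)

drawn = torch.tensor(list(sampler))  # indices of the patches selected for one epoch
print(has_object[drawn].float().mean().item())  # fraction of object patches, close to 0.5
```

In a real training loop you would pass sampler=sampler to your DataLoader instead of shuffle=True (DataLoader rejects the two together).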
Depending on the character of your data, any given object patch may still
have significantly more background pixels than foreground pixels, so your
training data could remain unbalanced in this regard. If so, use
BCEWithLogitsLoss’s pos_weight constructor argument to weight the (fewer)
foreground pixels more heavily in the loss function.
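Here is a sketch using made-up tensors. It sets pos_weight to the ratio of background to foreground pixels, the rule of thumb given in the BCEWithLogitsLoss documentation; in practice you would compute that ratio once over your (sampled) training patches rather than per batch. The (1, 1, 1) shape lets pos_weight broadcast over the channel dimension of (N, 1, H, W) predictions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up batch of object patches: logits and float targets of shape (N, 1, H, W).
logits = torch.randn(4, 1, 32, 32)
target = torch.zeros(4, 1, 32, 32)
target[:, :, 12:16, 12:16] = 1.0  # a small foreground region in each patch

# pos_weight = (number of background pixels) / (number of foreground pixels).
num_pos = target.sum()
num_neg = target.numel() - num_pos
pos_weight = (num_neg / num_pos).reshape(1, 1, 1)  # broadcasts over (N, 1, H, W)

loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = loss_fn(logits, target)

# Hand-computed check: only the foreground (target == 1) term is scaled by pos_weight.
p = torch.sigmoid(logits)
manual = -(pos_weight * target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()
print(pos_weight.item(), loss.item(), manual.item())
```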
One note: precise use of U-Net has some nuances. When used with proper
care (for example, input tiles that overlap enough that every output pixel
sees its full field of view), running a large image through U-Net as a bunch
of separate patches using the tiling strategy yields exactly the same
predictions as if you had run the whole image through U-Net all at once.
(This is a good thing.)
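To illustrate what that care involves, here is a simplified sketch. The network is a hypothetical stand-in made of three unpadded 3x3 convolutions with no pooling, run once on a whole image and once tile by tile, with each input tile extended by a border (HALO, also hypothetical) matching how much the network trims. The two results agree. A real U-Net also pools, so in addition the tile offsets must be multiples of the total downsampling factor (16 for four 2x2 poolings), and the overlap must account for the full network’s field of view; this sketch leaves pooling out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a U-Net: three unpadded ("valid") 3x3 convolutions, no pooling.
# Each conv trims one pixel from every border, so the output is 2 * HALO pixels
# smaller than the input in each dimension.
HALO = 3
net = nn.Sequential(
    nn.Conv2d(1, 8, 3), nn.ReLU(),
    nn.Conv2d(8, 8, 3), nn.ReLU(),
    nn.Conv2d(8, 1, 3),
).double().eval()


def predict_tiled(net, image, tile):
    """Build the same output as net(image), one output tile at a time (hypothetical helper).

    Each input tile is the output tile plus a HALO-pixel border, so every output
    pixel sees exactly the input pixels it would see in a full-image pass.
    """
    _, _, h, w = image.shape
    out_h, out_w = h - 2 * HALO, w - 2 * HALO
    out = image.new_zeros(1, 1, out_h, out_w)
    for top in range(0, out_h, tile):
        for left in range(0, out_w, tile):
            bottom, right = min(top + tile, out_h), min(left + tile, out_w)
            # Input region covering output rows [top, bottom) and columns [left, right).
            patch = image[:, :, top:bottom + 2 * HALO, left:right + 2 * HALO]
            out[:, :, top:bottom, left:right] = net(patch)
    return out


with torch.no_grad():
    image = torch.rand(1, 1, 70, 90, dtype=torch.float64)
    full = net(image)
    tiled = predict_tiled(net, image, tile=16)
print(torch.allclose(full, tiled))
```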
Best.
K. Frank