Imbalance in training data

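What you’re doing is semantic segmentation, and it’s important to understand that this is effectively a classification task: every pixel gets exactly one class label, so the imbalance that matters is how many pixels belong to each class.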
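To make the "it's just classification" point concrete, here is a minimal NumPy sketch. It assumes the common layout of logits shaped (N, C, H, W) and integer masks shaped (N, H, W). The helper name `to_pixel_classification` is hypothetical. Once every pixel is treated as one sample, you have an ordinary C-class classification problem.

```python
import numpy as np


def to_pixel_classification(logits, masks):
    """Flatten segmentation outputs into per-pixel classification samples.

    logits: float array of shape (N, C, H, W), one score per class per pixel.
    masks:  int array of shape (N, H, W), one class index per pixel.
    Returns (scores, labels) with shapes (N*H*W, C) and (N*H*W,).
    """
    n, c, h, w = logits.shape
    # Move the class axis last so each row holds one pixel's class scores.
    scores = logits.transpose(0, 2, 3, 1).reshape(-1, c)
    # Flatten masks in the same (n, h, w) order so rows and labels line up.
    labels = masks.reshape(-1)
    return scores, labels


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 4, 4))      # batch of 2, 3 classes, 4x4 images
masks = rng.integers(0, 3, size=(2, 4, 4))  # one label per pixel
scores, labels = to_pixel_classification(logits, masks)
print(scores.shape, labels.shape)  # (32, 3) (32,)
```

Many loss functions, PyTorch's `torch.nn.functional.cross_entropy` among them, accept the unflattened (N, C, H, W) / (N, H, W) shapes directly. The flattening here only illustrates what they compute.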

With that out of the way, all we need to do is calculate per-class weights and pass them to the cross_entropy_loss function. This post explains how to do that, with code:
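For reference, here is a minimal NumPy sketch of one common weighting scheme. It uses "balanced" inverse frequency, the same formula scikit-learn applies for `class_weight="balanced"`: `total_pixels / (num_classes * pixels_in_class)`. The helper names `class_weights` and `weighted_cross_entropy` are hypothetical. The loss mirrors how PyTorch's `torch.nn.functional.cross_entropy` uses its `weight` argument with the default mean reduction: a weighted sum of per-pixel losses divided by the summed weights of the target classes. If you are on PyTorch, you would pass `weight=torch.as_tensor(weights, dtype=torch.float32)` instead of using the NumPy loss.

```python
import numpy as np


def class_weights(labels, num_classes):
    """Balanced inverse-frequency weights computed from per-pixel labels.

    Classes that never occur get weight 0 to avoid dividing by zero
    (an assumption; you may prefer a small floor instead).
    """
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(float)
    weights = np.zeros(num_classes)
    present = counts > 0
    weights[present] = labels.size / (num_classes * counts[present])
    return weights


def weighted_cross_entropy(scores, labels, weights):
    """Weighted mean cross-entropy for scores (P, C) and labels (P,)."""
    # Numerically stable log-softmax over the class axis.
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the true class for each pixel.
    nll = -log_probs[np.arange(labels.size), labels]
    w = weights[labels]
    return (w * nll).sum() / w.sum()


# Ten pixels: 90% background (class 0), 10% foreground (class 1).
labels = np.array([0] * 9 + [1])
weights = class_weights(labels, num_classes=2)  # rare class weighted 9x more
scores = np.zeros((labels.size, 2))             # uniform predictions
loss = weighted_cross_entropy(scores, labels, weights)
```

With these weights the single foreground pixel contributes as much to the loss as all nine background pixels combined, which is the point of rebalancing.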