Class imbalance with image segmentation

Could you explain your use case a bit more and how you would like to sample the masks?
Usually you would create sample weights, which assign a weight to each sample in the Dataset (i.e. each image and mask pair). From your description, it sounds like you would instead like to sample pixels from your masks. Is that correct, or am I misunderstanding your use case?
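To illustrate the per-sample approach, here is a minimal sketch that computes one weight per image/mask pair. The weighting rule is only an assumption for illustration: each pair gets the inverse of the number of samples that contain its rarest present class, so pairs showing rare classes are drawn more often. The helper names `sample_weights_from_masks` and `draw_sample_indices` are hypothetical. The sketch uses plain Python and `random.choices` to stay self-contained. In PyTorch you would typically pass the same weight list to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)` and hand that sampler to the `DataLoader` via `sampler=`.

```python
import random
from collections import Counter


def sample_weights_from_masks(masks):
    """Return one weight per (image, mask) pair.

    Hypothetical heuristic: weight = 1 / (number of samples containing the
    rarest class present in this mask), so rare-class samples get larger weights.
    """
    # Set of class indices present in each (non-empty) 2D mask
    present = [set(v for row in mask for v in row) for mask in masks]
    # Number of samples in which each class index appears at least once
    class_counts = Counter(c for classes in present for c in classes)
    # Use the largest inverse frequency among the classes in each mask
    return [max(1.0 / class_counts[c] for c in classes) for classes in present]


def draw_sample_indices(weights, num_samples, seed=0):
    """Draw dataset indices with replacement, proportional to their weights."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


# Toy example: class 0 is background, class 1 only appears in the second mask
masks = [
    [[0, 0], [0, 0]],
    [[0, 0], [0, 1]],
    [[0, 0], [0, 0]],
    [[0, 0], [0, 0]],
]
weights = sample_weights_from_masks(masks)  # [0.25, 1.0, 0.25, 0.25]
indices = draw_sample_indices(weights, num_samples=1000)
```

Note that this only changes how often whole image/mask pairs are drawn; every pixel of a drawn mask is still used, which is different from sampling individual pixels.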

Also, for reference: this answer might belong in this thread.