UNet Multiclass Loss Function Selection

You are correct that there are only 4 channels (RGB + background as black) in my masks. Each channel/class is binary with the black background as zeros. I have my model set up for only 3 output channels.

I do want to keep the multi-label aspect of this classification so that I can analyze each channel independently in a postprocessing step. So it is important to me that the network learns “class independence”, so to speak (although that might be the wrong term). For instance, I want the network to learn there are “blue plate pixels” under the QR code, so that I can analyze the blue channel without worrying about it having a QR code-shaped hole through it. I want to be able to measure the blue plate characteristics regardless of the QR code or any other overlapping label.

The same logic follows for the other channels, where it’s important to have overlapping segments for this application. I hope that makes sense.
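
To make the overlap idea concrete, here is a minimal sketch of the postprocessing I have in mind. The channel indices, the tiny 2x2 image, and the 0.5 threshold are all made up for illustration and are not from my actual pipeline.

```python
import torch

# Hypothetical logits for one image, shape [N, C, H, W] with C = 3 label channels.
# Start with every label strongly "off".
logits = torch.full((1, 3, 2, 2), -5.0)

# Illustrative assumption: channel 0 is the plate and covers the whole image,
# channel 2 is the QR code and covers a single pixel on top of the plate.
logits[0, 0] = 5.0
logits[0, 2, 0, 0] = 5.0

# Sigmoid is applied element-wise, so each channel gets its own probability
# and a pixel can be positive in several channels at once.
probs = torch.sigmoid(logits)
masks = probs > 0.5  # assumed threshold

# Pixels that are positive in more than one channel (overlapping labels).
overlap = masks.sum(dim=1) > 1

plate_mask = masks[0, 0]  # has no QR-shaped hole
```

With a softmax over channels the per-pixel probabilities would sum to 1, so the plate channel could not stay "on" underneath the QR code, which is why I want to keep the per-channel sigmoid setup.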

With this in mind, I believe BCEWithLogitsLoss is still the way to go based on your earlier description. However, I’m having some trouble setting up the weights with this loss function (see below). I found this post (from you) that offers some explanation, but I’m still confused about how to get the weights up and running.

# red, blue, green channel weights = 6, 1, 30
loss_func = nn.BCEWithLogitsLoss(
    pos_weight=torch.Tensor([1, 30, 6]).repeat(BATCH_SIZE, 1))

# RuntimeError: The size of tensor a (3) must match the size of
# tensor b (256) at non-singleton dimension 3
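
My reading of the error is that a pos_weight of shape [BATCH_SIZE, 3] gets lined up against the trailing (height, width) dimensions of the target rather than the channel dimension. Below is a sketch of the shape I think it needs instead, assuming the model output and target are both [N, 3, H, W]. The 256x256 size is inferred from the error message, the batch size of 2 and the random tensors are placeholders, and the weight order is copied from the tensor in my snippet above.

```python
import torch
import torch.nn as nn

BATCH_SIZE, NUM_CLASSES, H, W = 2, 3, 256, 256  # assumed sizes

# One weight per channel. Shape [C, 1, 1] broadcasts against [N, C, H, W],
# so no repeat over the batch dimension is needed.
pos_weight = torch.tensor([1.0, 30.0, 6.0]).view(NUM_CLASSES, 1, 1)
loss_func = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Random stand-ins for the model output and the binary multi-label masks.
logits = torch.randn(BATCH_SIZE, NUM_CLASSES, H, W)
targets = torch.randint(0, 2, (BATCH_SIZE, NUM_CLASSES, H, W)).float()

loss = loss_func(logits, targets)  # scalar, mean over all elements
```

If this is right, the same pos_weight tensor should work for any batch size, since the batch dimension is filled in by broadcasting.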

Any advice?