Unsure about what the targets for a CNN should look like

Hi!
I created my own U-Net model following some walkthrough tutorials, trying to stay as close to the original paper as possible while making some small changes such as nb_classes and padding. The model looks like this:

import numpy as np
import torch
import torch.nn as nn

def double_conv(in_c, out_c):
    # two 3x3 convolutions, each followed by a ReLU
    # (padding=1 keeps the spatial size unchanged)
    conv = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )
    return conv

def crop_img(tensor, target_tensor):
    # center-crop the skip connection `tensor` to the spatial size of
    # `target_tensor` (assumes square feature maps)
    target_size = target_tensor.size()[2]
    if tensor.size()[2] % 2 == 1:
        # drop one row/column so the crop stays symmetric on odd sizes
        tensor_size = tensor.size()[2] - 1
    else:
        tensor_size = tensor.size()[2]
    delta = (tensor_size - target_size) // 2
    return tensor[:, :, delta:tensor_size-delta, delta:tensor_size-delta]
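
# quick self-check of crop_img (hypothetical odd-sized skip connection):
#   crop_img(torch.randn(1, 64, 57, 57), torch.randn(1, 64, 56, 56)).shape
#   -> torch.Size([1, 64, 56, 56])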

class UNet(nn.Module):
    def __init__(self, nb_classes):
        super(UNet, self).__init__()
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv(3, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        ## transposed convolutions
        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.up_conv_4 = double_conv(128, 64)

        self.out = nn.Conv2d(64, nb_classes, 1)


    def forward(self, image):
        # encoder part
        # input image
        x1 = self.down_conv_1(image) # this is passed to decoder
        # max pooling
        x2 = self.max_pool_2x2(x1)

        x3 = self.down_conv_2(x2) # this is passed to decoder
        x4 = self.max_pool_2x2(x3)

        x5 = self.down_conv_3(x4) # this is passed to decoder
        x6 = self.max_pool_2x2(x5)

        x7 = self.down_conv_4(x6) # this is passed to decoder
        x8 = self.max_pool_2x2(x7)

        x9 = self.down_conv_5(x8)

        # decoder part
        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))

        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))

        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))

        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))

        x = self.out(x)
        return x
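
For reference, a quick shape check with a dummy batch (the 512x512 input size matches my data below):

model = UNet(nb_classes=6)
dummy = torch.randn(1, 3, 512, 512)   # [batch, channels, height, width]
print(model(dummy).shape)             # torch.Size([1, 6, 512, 512])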

Then, before jumping to some larger dataset, I wanted to learn and understand how everything works, so I created images and masks using this GitHub repo.

image.shape:   (512, 512, 3)
image.min(), image.max():   0.0 255.0
mask.shape:   (6, 512, 512)
mask.min(), mask.max():   0.0 1.0

So the masks are 6-channel arrays, where each channel represents one color, and sometimes the channels overlap, as in the example image below.
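
To see what torch.argmax does where channels overlap, here is a toy example (values made up):

import torch

# toy 2-channel 2x2 mask; pixel (0, 0) is set in both channels
mask = torch.tensor([[[1., 1.],
                      [0., 0.]],
                     [[1., 0.],
                      [0., 1.]]])
print(torch.argmax(mask, dim=0))
# tensor([[0, 0],
#         [0, 1]]) -> the overlapping pixel collapses to a single channel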

What I’m doing next is just preparing a dummy training and evaluation loop; please forgive me if this code hurts your eyes :see_no_evil:

    model = UNet(nb_classes=6)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #summary(model, input_size=(3, 128, 128))

    for t in range(99):
        image = load_image(f'data/images/image_{t}.jpg')
        mask = np.load(f'data/masks/mask_{t}.npy')

        # images loaded as 3-channel RGB: add a batch dimension, go from
        # NHWC to NCHW, and cast to float32 to match the model weights
        image = np.expand_dims(image, 0)
        image = image.transpose(0, 3, 1, 2)
        input_data = torch.tensor(image, dtype=torch.float32)

        # masks loaded as 6-channel numpy arrays
        mask_data = torch.tensor(mask)

        # argmax to combine all channels?
        y_true = torch.argmax(mask_data, dim=0)
        # adding one batch dimension
        y_true = y_true.unsqueeze(0)
        y_pred = model(input_data)

        loss = criterion(y_pred, y_true)
        print(t, loss.item())

        # zero gradients, perform a backward pass, and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    image = load_image('data/images/image_102.jpg')
    image = np.expand_dims(image, 0)
    image = image.transpose(0, 3, 1, 2)
    input_data = torch.tensor(image, dtype=torch.float32)

    with torch.no_grad():
        pred = model(input_data)
    pred = torch.sigmoid(pred)  # torch.sigmoid, since F.sigmoid is deprecated
    print(pred.shape)
    np.save('pred', pred.numpy())  # convert to numpy before saving
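
(A side thought: since I trained with CrossEntropyLoss, maybe the prediction should go through a softmax and an argmax instead of a sigmoid? A sketch of what I mean, assuming pred has shape [1, 6, H, W]:)

probs = torch.softmax(pred, dim=1)   # per-pixel class probabilities
labels = torch.argmax(probs, dim=1)  # [1, H, W] integer class map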

To summarize, I have the following questions:

  1. As I understand it, in multi-label segmentation the masks can overlap (each pixel can be assigned to multiple objects), whereas in multi-class segmentation each pixel can be assigned to only one object?

So for the multi-class case, we can have a one-channel 3x3 mask that could look like this:

0 1 2
1 2 3
4 5 5

We would prepare the targets as [ batch_size, channel=1, height, width ], please correct me if I am wrong (see the shape check below).
Which loss function should I use in this case?
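
For reference, a minimal shape check of nn.CrossEntropyLoss (random values, made-up sizes), which seems to suggest the target has no channel dimension:

logits = torch.randn(2, 6, 4, 4)         # input: [batch, nb_classes, H, W]
target = torch.randint(0, 6, (2, 4, 4))  # target: [batch, H, W], class indices
loss = torch.nn.CrossEntropyLoss()(logits, target)
print(loss.item())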

  2. How should I proceed with multi-label segmentation? (My guess is sketched below.)
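
My current guess: keep all 6 channels as float targets and use nn.BCEWithLogitsLoss, which applies a sigmoid per channel so overlapping labels are allowed (a sketch with made-up sizes):

logits = torch.randn(2, 6, 4, 4)                    # [batch, nb_classes, H, W]
target = torch.randint(0, 2, (2, 6, 4, 4)).float()  # same shape, 0/1 floats, may overlap
loss = torch.nn.BCEWithLogitsLoss()(logits, target)
print(loss.item())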

  3. Right now the loss drops to around 0.05 after ~40 loops. I’m using code from the linked GitHub repo:

import matplotlib.pyplot as plt
import helper  # helper module from the linked repo

loaded_array = np.load('pred.npy')
print(loaded_array.shape)
pred_rgb = [helper.masks_to_colorimg(x) for x in loaded_array]
plt.imshow(pred_rgb[0])

The masks for image 102 look like this after 99 training loops:

[image: predicted masks after 99 loops]

What am I doing wrong?