How to create a UNet model whose target has 3 channels (RGB)

Hi folks,
I’ve just finished a UNet model for multi-class (5 classes) segmentation where the targets are grayscale images (1 channel). I would like to extend the model to targets with 3 channels (RGB), but I haven’t succeeded yet. Here are some snippets from the 1-channel version:

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class myDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir) 

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))       
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        # mask = np.array(Image.open(mask_path).convert("RGB"))

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        
        return image, mask

The transform function resizes to (H, W), applies rotation and normalization, and converts the result to a tensor.
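
Since resizing and rotation are applied to the mask as well, it may be worth checking which interpolation is used for it. A minimal sketch with torch.nn.functional.interpolate (not necessarily what the transform above uses internally) showing why label masks are usually resized with nearest-neighbor interpolation: bilinear interpolation blends neighboring class values into numbers that are not valid classes. The 4x4 mask is a made-up example.

```python
import torch
import torch.nn.functional as F

# Hypothetical 4x4 label mask with class indices 0..4, shaped [N, C, H, W] for interpolate
mask = torch.tensor([[0, 0, 1, 1],
                     [2, 2, 3, 3],
                     [2, 2, 3, 3],
                     [4, 4, 4, 4]], dtype=torch.float32)[None, None]

# Nearest-neighbor resizing only copies existing values, so every pixel stays a valid class
nearest = F.interpolate(mask, size=(8, 8), mode="nearest")

# Bilinear resizing blends neighbors and can produce values that are not class indices
bilinear = F.interpolate(mask, size=(8, 8), mode="bilinear", align_corners=False)

print(torch.unique(nearest))   # only values from {0, 1, 2, 3, 4}
print(torch.unique(bilinear))  # also contains fractional "in-between" values
```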

import torch

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0 #from [0,1] similar as F1 score
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            print('x {}, y {}'.format(x.shape,y.shape))
            y = y.to(device).unsqueeze(1) # y torch.Size([1, 180, 100, 3])
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float() #preds torch.Size([1, 5, 180, 100])
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

I already tried mask = np.array(Image.open(mask_path).convert("RGB")), but it raised a tensor shape mismatch in the check_accuracy function: preds has shape torch.Size([1, 5, 180, 100]) while y has shape torch.Size([1, 180, 100, 3]).

I set up training as follows:

    model = UNET(in_channels=3, out_channels=5).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
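
For reference, a minimal sketch of the shapes nn.CrossEntropyLoss expects in its default (class-index) mode, using random tensors in place of the real model output and masks; the batch size of 2 is arbitrary:

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 5, 180, 100  # batch, classes (out_channels=5), height, width

# Raw, unnormalized model output (logits): one score per class per pixel
logits = torch.randn(N, C, H, W)

# Target in the default mode: one class index per pixel,
# dtype long, values in [0, C-1], and no channel dimension
target = torch.randint(0, C, (N, H, W), dtype=torch.long)

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
print(loss.shape)  # scalar, because the default reduction is "mean"
```

In other words, with this loss the mask returned by the dataset would need to be an [H, W] tensor of integer class indices (e.g. converted with .long()), not an RGB image.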

Here is the traceback:

Traceback (most recent call last)
e:\04_Educations\MachineLearning\ML_Online\Pytorch\petroleum\train.py in <cell line: 109>()
    123         save_predictions_as_imgs(val_loader, model, folder="saved_images/", device=DEVICE)
    126 if __name__ == "__main__":
--> 127     main()

e:\04_Educations\MachineLearning\ML_Online\Pytorch\petroleum\train.py in main()
     102 if LOAD_MODEL:
     103     load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)
---> 106 check_accuracy(val_loader, model, device=DEVICE)
     107 scaler = torch.cuda.amp.GradScaler()
     109 for epoch in range(NUM_EPOCHS):

e:\04_Educations\MachineLearning\ML_Online\Pytorch\utils.py in check_accuracy(loader, model, device)
     70 preds = (preds > 0.5).float()
     71 print('check accuracy preds',preds.shape)
---> 72 num_correct += (preds == y).sum()
     73 num_pixels += torch.numel(preds)
     74 dice_score += (2 * (preds * y).sum()) / (
     75     (preds + y).sum() + 1e-8
     76 )

RuntimeError: The size of tensor a (100) must match the size of tensor b (3) at non-singleton dimension 4
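
The "dimension 4" in the message comes from broadcasting: unsqueeze(1) turns the RGB mask into a 5-D tensor [1, 1, 180, 100, 3], and PyTorch aligns it with the 4-D preds starting from the last dimension, so 100 (the width in preds) is compared against 3 (the RGB channels in y). A small reproduction with zero tensors of the same shapes:

```python
import torch

preds = torch.zeros(1, 5, 180, 100)  # thresholded model output, channels-first
y = torch.zeros(1, 180, 100, 3)      # RGB mask left in channels-last layout
y = y.unsqueeze(1)                   # -> [1, 1, 180, 100, 3], as in check_accuracy

error_message = None
try:
    # preds is broadcast as [1, 1, 5, 180, 100]; the last dims 100 vs 3 do not match
    (preds == y).sum()
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```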

Please share any advice, thank you!

PyTorch uses the channels-first layout, so your target should have the shape [1, 3, 180, 100] if you want to use 3 channels.
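
A minimal sketch of that layout change, assuming the mask arrives as a channels-last [H, W, 3] array as in the dataset above (the zero array is a stand-in for a real mask). This only fixes the layout; it does not make an RGB mask comparable with 5-class predictions.

```python
import numpy as np
import torch

# Stand-in for an RGB mask loaded with PIL + np.array: channels-last [H, W, 3]
mask_hwc = np.zeros((180, 100, 3), dtype=np.uint8)

# Convert to a tensor and move the channel axis to the front: [3, H, W]
mask_chw = torch.from_numpy(mask_hwc).permute(2, 0, 1)

# After batching (e.g. by a DataLoader) this becomes [N, 3, H, W]
batch = mask_chw.unsqueeze(0)
print(batch.shape)
```
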
However, I’m not sure I fully understand the use case, since you won’t be able to compare the predictions for 5 classes with an RGB target in:

num_correct += (preds == y).sum()

Thanks @ptrblck,
num_correct is used to compute the accuracy during training. If num_correct += (preds == y).sum() is not applicable, how should I revise it to work with a 3-channel (RGB) target?
I also tried converting the input images to a single channel with image = np.array(Image.open(img_path).convert("L"), dtype=np.float32), but the error persists.
Let me share a bit of context: my target is a geological model with 5 layers.

That’s why I set out_channels=5. My goal is to use a UNet model to derive features from the input image and predict an output that is as similar to the target as possible. I have about 100 image/mask pairs and can collect more data if needed.
I’m new to this field and would appreciate any further advice. Thank you!

Based on your shared code it seems you are using nn.CrossEntropyLoss which expects the targets to contain class indices in the default case.
If you’ve already created these targets, you could use them to calculate the accuracy by comparing them against preds.
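
One way to build such targets from RGB masks is to map each color to a class index. A sketch assuming each of the 5 geological layers is painted with exactly one distinct RGB color; the PALETTE values and the rgb_to_class_indices helper are hypothetical and the colors must be replaced with the ones actually used in the masks:

```python
import numpy as np
import torch

# Hypothetical palette: the exact RGB color used for each of the 5 layers.
# Replace these placeholders with the colors that appear in your mask images.
PALETTE = np.array([
    [0, 0, 0],        # class 0
    [255, 0, 0],      # class 1
    [0, 255, 0],      # class 2
    [0, 0, 255],      # class 3
    [255, 255, 0],    # class 4
], dtype=np.uint8)


def rgb_to_class_indices(mask_rgb: np.ndarray) -> torch.Tensor:
    """Map an [H, W, 3] RGB mask to an [H, W] long tensor of class indices."""
    # Compare every pixel with every palette color: [H, W, 1, 3] vs [1, 1, K, 3] -> [H, W, K]
    matches = np.all(mask_rgb[:, :, None, :] == PALETTE[None, None, :, :], axis=-1)
    if not matches.any(axis=-1).all():
        raise ValueError("mask contains colors that are not in PALETTE")
    # Index of the matching palette color for each pixel
    return torch.from_numpy(matches.argmax(axis=-1)).long()


# Tiny example: a 2x2 mask using classes 0, 1, 2 and 4
mask_rgb = np.array([[[0, 0, 0], [255, 0, 0]],
                     [[0, 255, 0], [255, 255, 0]]], dtype=np.uint8)
class_mask = rgb_to_class_indices(mask_rgb)
print(class_mask)  # values [[0, 1], [2, 4]]
```

The function raises an error for unknown colors instead of guessing, which helps catch masks saved with lossy compression (e.g. JPEG) or anti-aliased edges, where boundary pixels can take intermediate colors. The resulting [H, W] long tensor can be used directly as the nn.CrossEntropyLoss target and in the accuracy comparison.
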
However, since nn.CrossEntropyLoss expects the model to output raw logits, this line of code also looks wrong:

preds = (preds > 0.5).float()

since you are comparing against a probability threshold, while preds = torch.argmax(output, dim=1), which returns the predicted class index for each pixel, is usually used for multi-class segmentation.
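
Following that suggestion, a possible rewrite of the per-batch metrics in check_accuracy, assuming the model returns raw logits [N, 5, H, W] and the targets are class-index tensors [N, H, W]. The evaluate_batch helper is hypothetical, and the Dice score here is averaged over the classes present in either the prediction or the target (one of several common conventions):

```python
import torch
import torch.nn.functional as F


def evaluate_batch(logits: torch.Tensor, target: torch.Tensor, num_classes: int):
    """Return (num_correct, num_pixels, mean_dice) for one batch.

    logits: raw model output [N, C, H, W]; target: class indices [N, H, W] (long).
    """
    preds = torch.argmax(logits, dim=1)  # [N, H, W], predicted class per pixel
    num_correct = (preds == target).sum().item()
    num_pixels = target.numel()

    # One-hot encode both to [N, C, H, W] so Dice can be computed per class
    preds_1h = F.one_hot(preds, num_classes).permute(0, 3, 1, 2).float()
    target_1h = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    intersection = (preds_1h * target_1h).sum(dim=(0, 2, 3))            # [C]
    denom = preds_1h.sum(dim=(0, 2, 3)) + target_1h.sum(dim=(0, 2, 3))  # [C]
    dice_per_class = 2 * intersection / (denom + 1e-8)

    # Average only over classes that occur in the prediction or the target
    present = denom > 0
    return num_correct, num_pixels, dice_per_class[present].mean().item()


# Usage with a tiny perfect prediction: 4 pixels, classes 0, 1, 2 and 4
target = torch.tensor([[[0, 1], [2, 4]]])                       # [1, 2, 2]
logits = F.one_hot(target, 5).permute(0, 3, 1, 2).float() * 10  # [1, 5, 2, 2]
num_correct, num_pixels, dice = evaluate_batch(logits, target, num_classes=5)
print(num_correct, num_pixels, round(dice, 4))
```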

Thanks @ptrblck, let me dig deeper into that.