How can I display the predicted mask from UNet?

Hi everyone. I am trying to plot the original image, mask and predicted mask from the UNet model, however, I am getting weird images as my output. This is a method I used before, which worked perfectly but for some weird reason, it is not working anymore…

As shown in the image above, I am not getting an image at all…

Here is how I am calling the dataset:

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class GetData(Dataset):
def init(self, img_path, mask_path, transform=None):
self.img_path = img_path
self.mask_path = mask_path
self.transform = transform

def __len__(self):
    return len(self.img_path)

def __getitem__(self, idx):
    image = np.array(Image.open(self.img_path[idx]).convert('RGB'), 
                     dtype=np.float32)
    mask = np.array(Image.open(self.mask_path[idx]).convert('L'), 
                     dtype=np.float32)
    mask = ((mask/np.max([mask.max(), 1e-8])) > 0.5).astype(np.float32)
    
    
    if self.transform is not None:
        augmentations = self.transform(image=image, mask=mask)
        image = augmentations['image']
        mask = augmentations['mask']

    
    return image, mask

And here is the testing loop which prints to the images:

def testinge_loop (model, loader, device=torch.device(‘cuda’)):

dice = []
model.eval()
with torch.no_grad():
    for x, y in loader:
        
        # split data to image and mask
        image = x
        mask = y
        
        image = image.to(device)
        mask = mask.to(device)
        
        test_outputs = model(image)
        test_outputs = torch.sigmoid(test_outputs)

        dice_metric = dice_metrics(test_outputs, mask)
        dice.append(dice_metric.cpu().numpy())
        
        
        plt.figure(figsize=(12,6), dpi=1200)
        plt.subplot(1,3,1)
        plt.title('Original Image')
        plt.plot(image[0].permute(1,2,0).detach().cpu()[:,:,0])
        plt.subplot(1,3,2)
        plt.title('Ground Truth')
        plt.plot(mask.permute(1,2,0).detach().cpu()[:,:,0])
        plt.subplot(1,3,3)
        plt.title('Predicted Mask')
        plt.plot(test_outputs[0].permute(1,2,0).detach().cpu()[:,:,0])
        break
        
metric = sum(dice)/len(loader)
return metric 

I genuinely have no idea what I am doing wrong…I tried to copy what others did but that seems to cause more errors… Thank you in advance.

Try to use plt.imshow to display images instead of plt.plot as the latter will interpret the array as rows of different signals.