U-Net segmentation issue

Hi All,

I have managed to train a U-Net model on farm data. The input data was resized to (224, 224) prior to training; the images (train/val and test) are RGB and the masks are binary. Training and validation prediction both work fine.
However, when I run the model on the test set I get an error. Below is the code :slight_smile:

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchsummary import summary
import cv2

image_dirs = sorted(os.listdir('/kaggle/input/crop-unet/crop_delineation/imgs'))
image_paths = [f'/kaggle/input/crop-unet/crop_delineation/imgs/{image_dir}' for image_dir in image_dirs]

mask_dirs = sorted(os.listdir('/kaggle/input/crop-unet/crop_delineation/masks_filled'))
mask_paths = [f'/kaggle/input/crop-unet/crop_delineation/masks_filled/{mask_dir}' for mask_dir in mask_dirs]

len(image_paths), len(mask_paths)

torchvision.transforms.ToTensor()(Image.open(image_paths[0]).convert('RGB')).shape

image = torch.tensor(np.array(Image.open(image_paths[0]).convert('RGB'))).permute(2, 0, 1)
mask = torch.tensor(np.array(Image.open(mask_paths[0])))

img_seg = torchvision.utils.draw_segmentation_masks(image,
                                                    mask.to(torch.bool),
                                                    alpha=0.3,
                                                    colors='blue')

plt.imshow(img_seg.permute(1, 2, 0))

class FloodAreaDataset(Dataset):
    def __init__(self, image_paths,
                 mask_paths,
                 transform_image=None,
                 transform_mask=None):

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform_image = transform_image
        self.transform_mask = transform_mask

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')

        mask_path = self.mask_paths[idx]
        mask = Image.open(mask_path)


        if self.transform_image:
            image = self.transform_image(image)
        if self.transform_mask:
            mask = self.transform_mask(mask)
        return image, mask

transform_image = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),  # Resize to 224x224 pixels

])

transform_mask = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),  # Resize to match the image size
    lambda x: (x > 0.5).float()  # Ensure mask is binary after resizing and conversion
])

data = FloodAreaDataset(image_paths,
                        mask_paths,
                        transform_image=transform_image,
                        transform_mask=transform_mask)

train_size = int(len(data)*0.8)
val_size = len(data)-train_size

train, val = random_split(data, [train_size, val_size])
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=16, shuffle=False)

next(iter(train_loader))[1].shape

class DoubleConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_channel, out_channel,
                                (3, 3), padding='same', bias=False)
        self.batchnorm_1 = nn.BatchNorm2d(out_channel)
        self.relu_1 = nn.ReLU()

        self.conv_2 = nn.Conv2d(out_channel, out_channel,
                                (3, 3), padding='same', bias=False)
        self.batchnorm_2 = nn.BatchNorm2d(out_channel)
        self.relu_2 = nn.ReLU()

    def forward(self, x):
        return self.relu_2(self.batchnorm_2(self.conv_2(self.relu_1(self.batchnorm_1(self.conv_1(x))))))

class U_Net(nn.Module):
    def __init__(self, in_channel, out_channel, hidden=[64, 128, 256, 512]):
        super().__init__()
        down = []
        for h in hidden:
            down.append(DoubleConv(in_channel, h))
            down.append(nn.MaxPool2d(kernel_size=(2, 2),
                                   stride=2))
            in_channel = h
        self.down = nn.ModuleList(down)

        up = []
        in_channel = hidden[-1]*2
        for h in reversed(hidden):
            up.append(nn.ConvTranspose2d(in_channel, h,
                                         (2, 2), stride=2))

            up.append(DoubleConv(h*2, h))
            in_channel = h

        self.up = nn.ModuleList(up)

        self.bottle_neck = nn.Conv2d(hidden[-1], hidden[-1]*2,
                                     (3, 3), padding='same')
        self.end_conv = nn.Conv2d(hidden[0], out_channel,
                                 (1, 1), padding='same')
#         self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        skip_connection = []

        for idx in range(0, len(self.down), 2):
            x = self.down[idx](x)
            skip_connection.append(x)
            x = self.down[idx+1](x)
        x = self.bottle_neck(x)

        skip_connection = skip_connection[::-1]

        for idx in range(0, len(self.up), 2):
            x = self.up[idx](x)
            x = torch.cat([x, skip_connection[idx//2]], dim=1)
            x = self.up[idx+1](x)
        x = self.end_conv(x)
        return x

import torch
from tqdm import tqdm

def train(model, dataloader, loss_fn, optimizer, device):
    model.train()  # Set model to training mode
    total_loss = 0
    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Training Loss: {avg_loss:.4f}")

def validate(model, dataloader, loss_fn, device):
    model.eval()  # Set model to evaluation mode
    total_loss = 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Validation Loss: {avg_loss:.4f}")


model = U_Net(3, 1)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

epochs = 30
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train(model, train_loader, loss_fn, optimizer, device)
    validate(model, val_loader, loss_fn, device)

example_img = next(iter(val_loader))[0][3:5].to(device)
with torch.no_grad():

    pred = model(example_img)
pred = (nn.functional.sigmoid(pred)>0.5).cpu()[0]

fig, ax = plt.subplots(1, 2)
img = (example_img* 255).to(torch.uint8)[0].cpu()
img_mask = torchvision.utils.draw_segmentation_masks(img, pred, alpha=0.6 ,colors='red')
ax[0].imshow(img_mask.permute(1, 2, 0))
ax[1].imshow(img.permute(1, 2, 0))

import os
import json
from PIL import Image
import cv2
import numpy as np
from shapely.geometry import Polygon
from tqdm import tqdm

# Define path to your test folder
test_folder_path = "/kaggle/input/test-cropseg/testU"  # Replace with your actual path

# List to store the predictions with polygon annotations
test_predictions = []

# Set the model to evaluation mode
model.eval()

# Loop through each image in the test folder
for image_filename in os.listdir(test_folder_path):
    # Get the full path of the image
    image_path = os.path.join(test_folder_path, image_filename)

    # Load the image
    image = Image.open(image_path).convert('RGB')

    # Get the image size
    image_width, image_height = image.size

    # Resize the image to match the input size of the model
    image_resized = image.resize((224, 224))

    # Convert the image to a tensor and move it to the device (CPU or GPU)
    image_tensor = transforms.ToTensor()(image_resized).unsqueeze(0).to(device)
    
    # Make prediction using the model
    with torch.no_grad():
        prediction = model(image_tensor)

    # Process the prediction (e.g., apply sigmoid for probability)
    prediction = (nn.functional.sigmoid(prediction) > 0.5).cpu().numpy()[0]

    # Convert the binary image to the appropriate data type for resizing
    prediction_uint8 = (prediction * 255).astype(np.uint8)

    # Resize the predicted mask to match the size of the original test image
    prediction_resized = cv2.resize(prediction_uint8, (image_width, image_height))

    # Convert the resized image to grayscale
    prediction_gray = cv2.cvtColor(prediction_resized, cv2.COLOR_BGR2GRAY)

    # Ensure the data type is correct (8-bit unsigned, single-channel)
    prediction_gray = prediction_gray.astype(np.uint8)

    # Threshold the image if necessary
    _, prediction_thresh = cv2.threshold(prediction_gray, 0, 255, cv2.THRESH_BINARY)

    # Find contours
    contours, _ = cv2.findContours(prediction_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Convert contours to polygon annotations
    polygon_annotations = []
    for contour in contours:
        # Convert contour points to numpy array
        contour_points = contour.reshape(-1, 2)

        # Create a Polygon object from the points
        polygon = Polygon(contour_points)

        # Simplify the polygon (optional)
        polygon = polygon.simplify(tolerance=1)  # Adjust tolerance as needed

        # Append the polygon to annotations
        polygon_annotations.append(polygon)

    # Create an annotation dictionary for the current image
    annotation = {
        "file_name": image_filename,
        "class": "field",  # Replace with your actual class name
        "segmentation": []  # Empty list to store polygon coordinates
    }

    # Extract polygon coordinates and append to annotation
    for polygon in polygon_annotations:
        # Get polygon exterior coordinates
        exterior_coords = np.asarray(polygon.exterior.coords).ravel().tolist()
        annotation["segmentation"].extend(exterior_coords)

    # Append the image annotation to the test predictions list
    test_predictions.append(annotation)

# Create a dictionary with the image annotations
annotations_dict = {"images": test_predictions}

# Save the annotations to a JSON file
output_file = "/kaggle/working/predictions.json"
with open(output_file, "w") as f:
    json.dump(annotations_dict, f)

print(f"Predictions saved to: {output_file}")


I get this error below:

---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
Cell In[27], line 49
     46 prediction_resized = cv2.resize(prediction_uint8, (image_width, image_height))
     48 # Convert the resized image to grayscale
---> 49 prediction_gray = cv2.cvtColor(prediction_resized, cv2.COLOR_BGR2GRAY)
     51 # Ensure the data type is correct (8-bit unsigned, single-channel)
     52 prediction_gray = prediction_gray.astype(np.uint8)

error: OpenCV(4.9.0) /io/opencv/modules/imgproc/src/color.simd_helpers.hpp:92: error: (-2:Unspecified error) in function 'cv::impl::{anonymous}::CvtHelper<VScn, VDcn, VDepth, sizePolicy>::CvtHelper(cv::InputArray, cv::OutputArray, int) [with VScn = cv::impl::{anonymous}::Set<3, 4>; VDcn = cv::impl::{anonymous}::Set<1>; VDepth = cv::impl::{anonymous}::Set<0, 2, 5>; cv::impl::{anonymous}::SizePolicy sizePolicy = cv::impl::<unnamed>::NONE; cv::InputArray = const cv::_InputArray&; cv::OutputArray = const cv::_OutputArray&]'
> Invalid number of channels in input image:
>     'VScn::contains(scn)'
> where
>     'scn' is 224

Would anyone be able to help me with this, please?

Thanks & Best Regards
AMJS

The error message points to a channel-layout issue: OpenCV expects channels-last arrays of shape (H, W) or (H, W, C), but after `.cpu().numpy()[0]` the prediction still has shape (1, 224, 224), because indexing with [0] only removes the batch dimension. cv2.resize then treats it as a 1x224 image with 224 channels, which is why the traceback reports 'scn' is 224 when cv2.cvtColor runs. Squeeze out the channel dimension (or transpose to channels-last) before the OpenCV calls; since the mask is single-channel anyway, the BGR2GRAY conversion can be dropped entirely.
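
A minimal sketch of the adjusted post-processing, assuming the rest of your test loop stays as posted (same variable names; the 0.5 threshold is the one you already use for validation):

    # drop batch AND channel dims: (1, 1, 224, 224) -> (224, 224)
    prediction = torch.sigmoid(prediction).cpu().numpy()[0, 0]
    prediction_uint8 = ((prediction > 0.5) * 255).astype(np.uint8)  # 8-bit, single-channel binary mask

    # cv2.resize expects (width, height); nearest-neighbour keeps the mask binary
    prediction_resized = cv2.resize(prediction_uint8, (image_width, image_height),
                                    interpolation=cv2.INTER_NEAREST)

    # no cv2.cvtColor needed: the array is already 8-bit single-channel
    _, prediction_thresh = cv2.threshold(prediction_resized, 127, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(prediction_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

The contour-to-polygon part of the loop should then work unchanged.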