Test U-Net on one image (image-segmentation)

Hello,

I trained the U-Net on several images from a dataset, which were converted into PyTorch tensors listed in an ID_list.txt.
The model works fine, and even the evaluation on the other images looks good.

My question now is: how can I apply the state_dict of the model I trained and saved to a single (external) 256x256 image? How does this 256x256 image need to be converted so it can go through the U-Net and be used with the saved state_dict?

Hi @hfdp, what I understand from your explanation is that you want to use your trained model on a single image.
Assuming you have a trained model and an image:

# The model expects a batch dimension, so just set the batch size to 1.
image = image.reshape(1, channels, height, width)
output = model(image)

Hope this helps

Hey @Usama_Hasan,

I still have a question, though: the image I have is a float64 with size (256, 256). When I plot it, it looks fine.

When loading the model I am doing this:

forward_model.load_state_dict(torch.load(saved_models_path + '\\name_model.pt'))
forward_model.eval()
output = forward_model(image)

How should I define and reshape the image then, since it is already at the right size of 256x256?

You don’t have to reshape it to some other size; just add the batch dimension, which in your case would be 1. Do something like this:

image = torch.tensor(image)
image = image.view(1, no_channels, 256, 256)
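Since your array is float64 and the model's parameters are float32 by default, you will probably also need to cast it; a minimal sketch, assuming image is a NumPy array and the model takes a single channel:

import torch

# Assumption: image is a float64 NumPy array of shape (256, 256).
image = torch.from_numpy(image).float()   # cast to float32 to match the model weights
image = image.view(1, 1, 256, 256)        # (batch_size, channels, height, width)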

Hey @Usama_Hasan
One more question, though! I am using a standard U-Net I found on the internet and it isn’t working in this case due to something in the architecture. Should I change something in it to make the single image I mentioned earlier work? And if so, what? :slight_smile:

When removing the permute line I get this error: **Expected 4-dimensional input for 4-dimensional weight [32, 1, 3, 3], but got 2-dimensional input of size [256, 256] instead**

from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):

    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        # Reshape a (batch, height, width) input to (batch, 1, height, width)
        # so that it matches the single-channel Conv2d weights.
        x = x.unsqueeze(0).permute(1, 0, 2, 3)
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

PyTorch expects input in channels-first format, so you need to reshape your input, as I mentioned earlier, to (batch_size, num_channels, height, width).
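
Putting it all together, here is a minimal sketch of the whole single-image pass, assuming the permute line has been removed from forward, image is a float64 NumPy array of shape (256, 256), and saved_models_path points at the folder with your saved weights:

import torch

forward_model = UNet(in_channels=1, out_channels=1, init_features=32)
forward_model.load_state_dict(torch.load(saved_models_path + '\\name_model.pt'))
forward_model.eval()

# Assumption: image is a float64 NumPy array of shape (256, 256).
image = torch.from_numpy(image).float()    # cast to float32 to match the weights
image = image.unsqueeze(0).unsqueeze(0)    # (256, 256) -> (1, 1, 256, 256)

with torch.no_grad():                      # inference only, no gradients needed
    output = forward_model(image)          # (1, 1, 256, 256) sigmoid output

mask = output.squeeze().numpy()            # back to a (256, 256) array for plotting

If you keep the permute line instead, pass a 3-D tensor of shape (1, 256, 256), since forward then inserts the channel dimension itself.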