Change the dimensions of the UNET network output

As far as I know, a U-Net predicts outputs with the same spatial size as its input. How can I make the U-Net produce an output of a different size? (I need a specific height and width so that the output matches my ground-truth images.)
Please, how can I do this? I want the change to happen inside the network, not after prediction: when I resized the predicted image and then loaded the model again for testing, I got a size-mismatch error.

Your description is a little chaotic for me, sorry. Please try to describe your problem better.

Two ideas:

  • you can resize the output of the U-Net to your new dimensions, independently of PyTorch
  • you can modify the last layers of the U-Net to return specific dimensions. However, to help you with this, we need the code you are working with and a further explanation.

Thank you so much for your help. I am sorry; I am still new to this area.

I changed the last layer of the U-Net so that its output is half the size of the input image (scale factor 0.5), as in the following code:


class BaseConv(nn.Module):
    """Two stacked ``Conv2d -> ReLU`` layers: the basic U-Net conv unit.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size used by both convolutions.
        padding: Zero-padding applied by both convolutions.
        stride: Stride applied by both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride):
        super().__init__()
        self.act = nn.ReLU()
        # Pass stride/padding by keyword: nn.Conv2d's positional order is
        # (in_channels, out_channels, kernel_size, stride, padding), so
        # passing them positionally as (padding, stride) silently swapped
        # the two (harmless only when padding == stride).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)

    def forward(self, x):
        """Apply ``conv1 -> ReLU -> conv2 -> ReLU`` to ``x``."""
        x = self.act(self.conv1(x))
        return self.act(self.conv2(x))

class DownConv(nn.Module):
    """Encoder stage: 2x2 max-pool, then a ``BaseConv`` block.

    All constructor arguments are forwarded unchanged to ``BaseConv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride):
        super().__init__()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_block = BaseConv(in_channels, out_channels, kernel_size,
                                   padding, stride)

    def forward(self, x):
        """Return ``conv_block(pool1(x))``."""
        pooled = self.pool1(x)
        return self.conv_block(pooled)

class UpConv(nn.Module):
    """Decoder stage: upsample ``x`` 2x, concatenate ``x_skip``, convolve.

    Args:
        in_channels: Channels of the decoder input ``x`` (kept unchanged by
            the transposed convolution).
        in_channels_skip: Channels of the skip tensor ``x_skip``.
        out_channels: Channels produced by the conv block.
        kernel_size: Forwarded to ``BaseConv``.
        padding: Forwarded to ``BaseConv``.
        stride: Forwarded to ``BaseConv``.
    """

    def __init__(self, in_channels, in_channels_skip, out_channels,
                 kernel_size, padding, stride):
        super().__init__()
        self.conv_trans1 = nn.ConvTranspose2d(
            in_channels, in_channels, kernel_size=2, padding=0, stride=2)
        self.conv_block = BaseConv(
            in_channels=in_channels + in_channels_skip,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride)

    def forward(self, x, x_skip):
        """Upsample ``x``, align it to ``x_skip``, concatenate, and convolve."""
        x = self.conv_trans1(x)
        # If x_skip has an odd height/width, a 2x down/up round trip yields
        # a tensor one pixel smaller than x_skip, and torch.cat would fail.
        # Zero-pad (split as evenly as possible) so the spatial sizes match;
        # for evenly divisible sizes this is a no-op.
        diff_h = x_skip.size(2) - x.size(2)
        diff_w = x_skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x, [diff_w // 2, diff_w - diff_w // 2,
                    diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat((x, x_skip), dim=1)
        return self.conv_block(x)

class UNet(nn.Module):
    """Three-level U-Net that resizes its decoder output inside the network.

    Resizing inside ``forward`` (rather than on saved predictions) makes the
    output size part of the model, so training and testing agree.

    Args:
        in_channels: Channels of the input image.
        out_channels: Base feature width; the encoder levels use
            ``out_channels``, ``2*``, ``4*`` and ``8*out_channels`` channels.
        n_class: Output channels of the ``out`` head convolution.
        kernel_size: Kernel size for every conv block and the head.
        padding: Padding for every conv block and the head.
        stride: Stride for every conv block and the head.
        output_size: Optional target ``(height, width)`` (or one int for a
            square output). When given, the output is resized to exactly
            this size, e.g. the size of the ground-truth images.
        scale_factor: Resize factor used when ``output_size`` is None. The
            default 0.5 (half the input resolution) is the original behavior.

    ``forward`` returns the resized decoder features, which have
    ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, n_class, kernel_size,
                 padding, stride, output_size=None, scale_factor=0.5):
        super().__init__()
        # Plain attributes, not parameters/buffers: they are not part of the
        # state_dict, so existing checkpoints keep loading unchanged.
        self.output_size = output_size
        self.scale_factor = scale_factor
        self.init_conv = BaseConv(in_channels, out_channels, kernel_size,
                                  padding, stride)
        self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,
                              padding, stride)
        self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,
                              padding, stride)
        self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,
                              padding, stride)
        self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,
                          kernel_size, padding, stride)
        self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,
                          kernel_size, padding, stride)
        self.up1 = UpConv(2 * out_channels, out_channels, out_channels,
                          kernel_size, padding, stride)
        # Keyword arguments: nn.Conv2d's positional order is
        # (..., kernel_size, stride, padding), so passing (padding, stride)
        # positionally swapped them.
        self.out = nn.Conv2d(out_channels, n_class, kernel_size,
                             stride=stride, padding=padding)

    def forward(self, x):
        """Run encoder and decoder, then resize to the configured size."""
        # Encoder
        x = self.init_conv(x)
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        # Decoder
        x_up = self.up3(x3, x2)
        x_up = self.up2(x_up, x1)
        x_up = self.up1(x_up, x)
        # Resize inside the network so predictions already have the target
        # size; an exact output_size takes precedence over scale_factor.
        if self.output_size is not None:
            x_up = nn.functional.interpolate(
                x_up, size=self.output_size, mode='bilinear',
                align_corners=True)
        else:
            x_up = nn.functional.interpolate(
                x_up, scale_factor=self.scale_factor, mode='bilinear',
                align_corners=True)
        # NOTE(review): self.out (the n_class head) is constructed, and so is
        # present in the state_dict, but it is never applied here, so the
        # output has out_channels channels rather than n_class. Applying it
        # now would change outputs of already-trained checkpoints (its weights
        # presumably never received gradients) -- confirm the intent.
        return x_up

Then I trained the U-Net model using images of size [1, 3, 256, 256]; I didn't change anything else, I just saved the model. After training, I load the model as follows:

# The constructor arguments must match those used when the checkpoint was
# saved (here out_channels=1), otherwise loading reports size mismatches.
model = UNet(in_channels=3,
             out_channels=1,
             n_class=1,
             kernel_size=3,
             padding=1,
             stride=1)

# Keep strict loading (the default): strict=False silently ignores missing
# and unexpected keys, which hides a checkpoint/architecture mismatch.
# map_location="cpu" lets a GPU-saved checkpoint load on any machine; the
# values are copied into the model's own tensors.
model.load_state_dict(torch.load(model_param_path, map_location="cpu"))

but I got the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-f070578da7ca> in <module>
      1 model = model_name.to(device)
      2 
----> 3 model.load_state_dict(torch.load(model_param_path), strict=False)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1043         if len(error_msgs) > 0:
   1044             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1045                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1046         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1047 

RuntimeError: Error(s) in loading state_dict for UNet:
	size mismatch for init_conv.conv1.weight: copying a param with shape torch.Size([1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
	size mismatch for init_conv.conv1.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for init_conv.conv2.weight: copying a param with shape torch.Size([1, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for init_conv.conv2.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for down1.conv_block.conv1.weight: copying a param with shape torch.Size([2, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
	size mismatch for down1.conv_block.conv1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down1.conv_block.conv2.weight: copying a param with shape torch.Size([2, 2, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for down1.conv_block.conv2.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for down2.conv_block.conv1.weight: copying a param with shape torch.Size([4, 2, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
	size mismatch for down2.conv_block.conv1.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for down2.conv_block.conv2.weight: copying a param with shape torch.Size([4, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for down2.conv_block.conv2.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for down3.conv_block.conv1.weight: copying a param with shape torch.Size([8, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for down3.conv_block.conv1.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for down3.conv_block.conv2.weight: copying a param with shape torch.Size([8, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for down3.conv_block.conv2.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for up3.conv_trans1.weight: copying a param with shape torch.Size([8, 8, 2, 2]) from checkpoint, the shape in current model is torch.Size([512, 512, 2, 2]).
	size mismatch for up3.conv_trans1.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for up3.conv_block.conv1.weight: copying a param with shape torch.Size([4, 12, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 768, 3, 3]).
	size mismatch for up3.conv_block.conv1.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for up3.conv_block.conv2.weight: copying a param with shape torch.Size([4, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for up3.conv_block.conv2.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for up2.conv_trans1.weight: copying a param with shape torch.Size([4, 4, 2, 2]) from checkpoint, the shape in current model is torch.Size([256, 256, 2, 2]).
	size mismatch for up2.conv_trans1.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for up2.conv_block.conv1.weight: copying a param with shape torch.Size([2, 6, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 384, 3, 3]).
	size mismatch for up2.conv_block.conv1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for up2.conv_block.conv2.weight: copying a param with shape torch.Size([2, 2, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for up2.conv_block.conv2.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for up1.conv_trans1.weight: copying a param with shape torch.Size([2, 2, 2, 2]) from checkpoint, the shape in current model is torch.Size([128, 128, 2, 2]).
	size mismatch for up1.conv_trans1.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for up1.conv_block.conv1.weight: copying a param with shape torch.Size([1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 192, 3, 3]).
	size mismatch for up1.conv_block.conv1.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for up1.conv_block.conv2.weight: copying a param with shape torch.Size([1, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for up1.conv_block.conv2.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for out.weight: copying a param with shape torch.Size([1, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 64, 3, 3]).

Please give us fully reproducible code: a minimal example that saves and loads the model.

Here is working example:

# Save the weights of a freshly built model.
model = UNet(in_channels=3,
             out_channels=1,
             n_class=1,
             kernel_size=3,
             padding=1,
             stride=1)
torch.save(model.state_dict(), "./test.pt")

# Build a *separate* instance with the same constructor arguments. Writing
# `new_model = model = UNet(...)` would rebind `model` to the new instance,
# making the comparison below compare a model with itself.
new_model = UNet(in_channels=3,
                 out_channels=1,
                 n_class=1,
                 kernel_size=3,
                 padding=1,
                 stride=1)
new_model.load_state_dict(torch.load("./test.pt"))

# Both models should now produce identical outputs.
x = torch.rand((1, 3, 224, 224))
print(model(x).equal(new_model(x)))

Thank you so much for your help. I solved the problem: it was in how I saved the model. My model works fine now that I use

# Persist only the learned parameters (a plain state_dict); reload them
# with model.load_state_dict(torch.load(PATH)).
torch.save(obj=model.state_dict(), f=PATH)

But when I save a general checkpoint for inference using

# Save a general checkpoint: model weights plus optimizer state and
# bookkeeping values, all stored together in one dict.
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            # add any other entries you need here
            }, PATH)
# The saved file is a dict, not a bare state_dict. Passing it straight to
# model.load_state_dict() reports missing/unexpected keys; load the
# 'model_state_dict' entry instead:
#     checkpoint = torch.load(PATH)
#     model.load_state_dict(checkpoint['model_state_dict'])

the size-mismatch and missing-keys errors occur.