Conv3d in PyTorch: depth reduced in output

import torch
import torch.nn as nn

class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()

        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
            nn.Conv3d(64, 64, kernel_size=3, padding=1),
            nn.MaxPool3d(kernel_size=2, stride=2),  # halves each spatial dim, rounding down
        )

        # Middle layers
        self.middle = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.MaxPool3d(kernel_size=2, stride=2),  # halves each spatial dim again
        )

        # Decoder layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2),  # doubles each spatial dim
            nn.Conv3d(64, 64, kernel_size=3, padding=1),
            nn.ConvTranspose3d(64, out_channels, kernel_size=2, stride=2),  # doubles again
        )

    def forward(self, x):
        # Encoder
        x1 = self.encoder(x)

        # Middle
        x2 = self.middle(x1)

        # Decoder
        x3 = self.decoder(x2)

        return x3

# Define the number of classes
num_classes = 3

# Create an instance of the UNet3D model
model = UNet3D(in_channels=1, out_channels=num_classes)

# Test with a dummy input
dummy_input = torch.randn(8, 1, 200, 200, 17)  # Batch of 8 single-channel 200x200x17 volumes

output = model(dummy_input)
print(output.shape)  # spatial size comes back as 200 x 200 x 16


I have a tensor of shape (batch_size, channel, height, width, depth) = [8, 1, 200, 200, 17]. Ideally, when it is sent through the UNet, the output should have the same spatial dimensions. Am I making a mistake somewhere? The output I get has spatial size 200 x 200 x 16, i.e. it is one slice short along the depth dimension.

The thing is that the expansion part is subject to "rounding" of the sizes: your encoder halves the sizes twice, rounding down each time, and then your decoder doubles them twice. For a depth of 17 that gives 17 → 8 → 4 → 8 → 16, while 200 survives the round trip because it is divisible by 4. This is why ConvTranspose has an output_padding parameter, even if that might not be that useful to you here.
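To make the rounding concrete, here is a minimal sketch in plain Python that traces one spatial axis through the two pooling layers and the two transposed convolutions of the model above. The helper names (`pool_out`, `deconv_out`, `trace`) are made up for illustration; the formulas are the ones for `MaxPool3d` and `ConvTranspose3d` with `padding=0` and `dilation=1`. The 3x3x3 convolutions with `padding=1` keep the size unchanged, so they are left out.

```python
def pool_out(n, kernel=2, stride=2):
    """Output length of MaxPool3d along one axis (no padding, floor mode)."""
    return (n - kernel) // stride + 1


def deconv_out(n, kernel=2, stride=2, output_padding=0):
    """Output length of ConvTranspose3d along one axis (padding=0, dilation=1)."""
    return (n - 1) * stride + kernel + output_padding


def trace(n, last_output_padding=0):
    """Sizes of one axis after each resolution-changing layer of the model."""
    sizes = [n]
    n = pool_out(n)       # encoder pooling
    sizes.append(n)
    n = pool_out(n)       # middle pooling
    sizes.append(n)
    n = deconv_out(n)     # first transposed convolution
    sizes.append(n)
    n = deconv_out(n, output_padding=last_output_padding)  # second one
    sizes.append(n)
    return sizes


print(trace(17))                         # [17, 8, 4, 8, 16]
print(trace(200))                        # [200, 100, 50, 100, 200]
print(trace(17, last_output_padding=1))  # [17, 8, 4, 8, 17]
```

In the model this would correspond to something like `output_padding=(0, 0, 1)` on the last `ConvTranspose3d`. That restores depth 17 for this particular input, but it is hard-wired into the layer: an input whose depth is already divisible by 4 would then come out one slice too long, which is presumably why it is of limited use here.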

Best regards



Is there any way I could get the output tensor to have the same spatial dimensions as the input tensor? Obviously I could slice the tensors, but I don't want to do that since I will be backpropagating through them.

So the obvious options would be to

  • pad somewhere (either on the way down or on the way up) and crop at the end, or
  • use interpolate instead of deconvolution.

These come with drawbacks and advantages.
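A minimal sketch of both options, assuming a tiny stand-in network rather than the full UNet3D above. The helpers `pad_to_multiple` and `crop_like` are hypothetical names, and padding at the end of each axis with zeros is just one possible choice. Padding and slicing are both differentiable in autograd, so cropping the output does not get in the way of backprop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_to_multiple(x, multiple=4):
    """Zero-pad the last three dims of x up to the next multiple (hypothetical helper)."""
    pads = []
    for size in reversed(x.shape[-3:]):  # F.pad expects the last dim first
        extra = (-size) % multiple
        pads.extend([0, extra])          # pad only at the end of each axis
    return F.pad(x, pads)


def crop_like(y, reference):
    """Crop the last three dims of y to the spatial size of reference."""
    h, w, d = reference.shape[-3:]
    return y[..., :h, :w, :d]


# Stand-in for the network: two stride-2 poolings, two stride-2 deconvolutions.
net = nn.Sequential(
    nn.MaxPool3d(2),
    nn.MaxPool3d(2),
    nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2),
    nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2),
)

x = torch.randn(1, 1, 8, 8, 17, requires_grad=True)

# Option 1: pad on the way in (depth 17 -> 20), crop at the end (20 -> 17).
y_pad = crop_like(net(pad_to_multiple(x)), x)

# Option 2: upsample with interpolate to an explicit target size, then convolve.
pool = nn.MaxPool3d(2)
conv = nn.Conv3d(1, 1, kernel_size=3, padding=1)
h = pool(pool(x))  # spatial size 8 x 8 x 17 -> 2 x 2 x 4
y_interp = conv(
    F.interpolate(h, size=tuple(x.shape[-3:]), mode="trilinear", align_corners=False)
)

print(y_pad.shape, y_interp.shape)  # both match x's spatial size

# Gradients flow back through the padding and the crop.
y_pad.sum().backward()
print(x.grad.shape)
```

With padding, the extra zeros influence the values near the padded border. With interpolation, the target size is passed explicitly, so odd sizes are handled without any bookkeeping, at the cost of a fixed (non-learned) upsampling step.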
Another way (computationally very efficient, unless you think you lose too much information) could be to feed in images that are cropped so that each spatial size is divisible by an appropriate power of 2.
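A minimal sketch of the cropping approach, assuming two stride-2 downsampling steps as in the model above, so the appropriate power of 2 is 4. The helper `crop_to_multiple` is a hypothetical name, and it simply drops trailing slices; a centered crop would work just as well.

```python
import torch


def crop_to_multiple(x, multiple=4):
    """Crop the last three dims of x down to the nearest multiple (hypothetical helper)."""
    h, w, d = (s - s % multiple for s in x.shape[-3:])
    return x[..., :h, :w, :d]


x = torch.randn(2, 1, 20, 20, 17)
x_cropped = crop_to_multiple(x)  # depth 17 -> 16, the other axes are unchanged
print(x_cropped.shape)
```

For the input in the question this discards one of the 17 depth slices, after which halving twice and doubling twice returns exactly the cropped size. The target labels would need the same crop.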

Best regards


Thanks a bunch for those suggestions, Thomas. Much appreciated!