import torch
import torch.nn as nn
class UNet3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()
        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )
        # Middle layers
        self.middle = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )
        # Decoder layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv3d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(64, out_channels, kernel_size=2, stride=2)
        )

    def forward(self, x):
        # Encoder
        x1 = self.encoder(x)
        # Middle
        x2 = self.middle(x1)
        # Decoder
        x3 = self.decoder(x2)
        return x3
# Define the number of classes
num_classes = 3
# Create an instance of the UNet3D model
model = UNet3D(in_channels=1, out_channels=num_classes)
# Test with a dummy input
dummy_input = torch.randn(8, 1, 200, 200, 17)  # Batch of 8, single-channel 200x200x17 volumes
print(dummy_input.shape)
output = model(dummy_input)
print(output.shape)
I have a tensor of shape (batch_size, channels, height, width, depth), i.e. [8, 1, 200, 200, 17]. Ideally, when it is sent through the UNet, I should get back the same spatial dimensions, but I'm getting [8, 1, 200, 200, 16] instead, i.e. one slice short in the depth dimension. Am I making a mistake somewhere?
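To narrow it down, I traced the shapes through each stage. As far as I can tell, the problem is that the depth of 17 is odd: the first MaxPool3d floors 17 down to 8, the second floors 8 down to 4, and the two ConvTranspose3d layers can only double back 4 -> 8 -> 16, so the odd slice never comes back. Below is what I ran, plus a padding-and-cropping workaround I'm considering (just a sketch, assuming that padding the depth up to a multiple of 4 with F.pad and cropping the output afterwards is acceptable for my data):

import torch.nn.functional as F

# Trace the shapes stage by stage with the model defined above
x = torch.randn(8, 1, 200, 200, 17)
x1 = model.encoder(x)   # MaxPool3d floors 17 -> 8: [8, 64, 100, 100, 8]
x2 = model.middle(x1)   # second pool: 8 -> 4:      [8, 128, 50, 50, 4]
x3 = model.decoder(x2)  # transposed convs double back: 4 -> 8 -> 16
print(x1.shape, x2.shape, x3.shape)

# Possible workaround (sketch): pad the depth to the next multiple of 4
# (one factor of 2 per pooling layer), run the model, then crop back.
depth = x.shape[-1]            # 17
pad = (-depth) % 4             # 3, so the padded depth of 20 survives both pool/upsample rounds
x_padded = F.pad(x, (0, pad))  # F.pad pads the last dimension first
out = model(x_padded)          # [8, 3, 200, 200, 20]
out = out[..., :depth]         # crop back to the original depth of 17
print(out.shape)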