working on a large UNET3D that worked fine with input samples of shape (batch_n,1,64,64,32).
now i switch to a new dataset with different sample shape - (batch_n,1,79,95,68)
getting an error during concatenation of skip connection caused by dimension mismatch.
i’m pretty sure the root of the error is that a downsample (stride=2) of an odd number followed by an upsample by 2 causes a discrepancy. when the shape was 64,64,32 the division by 2 was always coordinated with the upsample by 2.
so far i couldn’t find a solution so hoping someone might have come across a similar issue in the past and can share an insight.
p.s
i tried simply turning the odd shape into an even one, like (batch_n,1,78,94,68), but i still get the same error. i guess i was just lucky with the original shape, which divides evenly by 2 at every level — 78 and 94 become odd again after a single halving, so the mismatch reappears deeper in the network.
here’s the code:
import torch.nn as nn
import torch
from torchsummary import summary
import torchsummaryX
from lib.medzoo.BaseModelClass import BaseModel
class UNet3D_inpainting(BaseModel):
    """3D U-Net for volume inpainting (regression; ``n_classes`` output channels).

    Implementation based on the 3D U-Net paper
    (https://arxiv.org/abs/1606.06650), adapted for volume inpainting.

    Shape fix for arbitrary input sizes: a stride-2, padding-1, kernel-3
    conv maps an odd spatial size S to ceil(S/2), but
    ``nn.Upsample(scale_factor=2)`` later produces 2*ceil(S/2) != S, so the
    skip-connection ``torch.cat`` fails whenever D/H/W are not divisible by
    2**4 (e.g. 79x95x68 — and 78/94, which become odd after one halving).
    The decoder therefore resizes every upsampled feature map to the exact
    spatial size of its skip-connection partner via :meth:`_match_size`
    before concatenating/summing. This is a no-op for power-of-two shapes,
    so behavior on 64x64x32 inputs is unchanged.
    """

    def __init__(self, in_channels, n_classes, args, base_n_filter=8):
        """
        Args:
            in_channels: number of channels of the input volume.
            n_classes: number of output channels (no classification head;
                this is a regression output).
            args: options namespace; must provide ``subject_num``, ``dim``,
                ``dropout`` (sequence, first entry used) and ``context_0``
                (scalar weight of the input residual added to the output).
            base_n_filter: channel width of the first encoder level; deeper
                levels use power-of-two multiples of it.
        """
        super(UNet3D_inpainting, self).__init__()
        self.n_classes = n_classes  # no classification, just regression
        self.in_channels = in_channels
        self.n_subjects = args.subject_num
        self.dim = args.dim
        self.base_n_filter = base_n_filter
        self.lrelu = nn.LeakyReLU()
        self.dropout3d = nn.Dropout3d(p=args.dropout[0])
        # NOTE: attribute keeps the original 'upsacle' (sic) spelling so any
        # external code referencing it keeps working.
        self.upsacle = nn.Upsample(scale_factor=2, mode='nearest')
        self.softmax = nn.Softmax(dim=1)
        self.context0 = int(args.context_0)

        # ---- encoder (context pathway) ----
        self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter,
                                     kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter,
                                     kernel_size=3, stride=1, padding=1, bias=False)
        self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
        self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)

        self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter * 2,
                                   kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter * 2, self.base_n_filter * 2)
        self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2)

        self.conv3d_c3 = nn.Conv3d(self.base_n_filter * 2, self.base_n_filter * 4,
                                   kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter * 4, self.base_n_filter * 4)
        self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4)

        self.conv3d_c4 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 8,
                                   kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter * 8, self.base_n_filter * 8)
        self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8)

        self.conv3d_c5 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 16,
                                   kernel_size=3, stride=2, padding=1, bias=False)
        self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter * 16, self.base_n_filter * 16)

        # ---- decoder (localization pathway) ----
        self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(
            self.base_n_filter * 16, self.base_n_filter * 8)
        self.conv3d_l0 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 8,
                                   kernel_size=1, stride=1, padding=0, bias=False)
        self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 8)

        self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 16)
        self.conv3d_l1 = nn.Conv3d(self.base_n_filter * 16, self.base_n_filter * 8,
                                   kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(
            self.base_n_filter * 8, self.base_n_filter * 4)

        self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 8)
        self.conv3d_l2 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 4,
                                   kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(
            self.base_n_filter * 4, self.base_n_filter * 2)

        self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 4)
        self.conv3d_l3 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 2,
                                   kernel_size=1, stride=1, padding=0, bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(
            self.base_n_filter * 2, self.base_n_filter)

        self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter * 2)
        self.conv3d_l4 = nn.Conv3d(self.base_n_filter * 2, self.n_classes,
                                   kernel_size=1, stride=1, padding=0, bias=False)

        # deep-supervision 1x1x1 heads from decoder levels 2 and 3
        self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter * 8, self.n_classes,
                                        kernel_size=1, stride=1, padding=0, bias=False)
        self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter * 4, self.n_classes,
                                        kernel_size=1, stride=1, padding=0, bias=False)

    @staticmethod
    def _match_size(x, ref):
        """Resize ``x`` so its spatial dims (D, H, W) match ``ref``'s.

        Nearest-neighbour interpolation is used to stay consistent with the
        network's nearest-mode upsampling. When the sizes already agree
        (inputs divisible by 16) this returns ``x`` untouched, so the
        original behavior is preserved.
        """
        if x.shape[2:] != ref.shape[2:]:
            x = nn.functional.interpolate(x, size=ref.shape[2:], mode='nearest')
        return x

    def conv_norm_lrelu(self, feat_in, feat_out):
        """Build a 3x3x3 conv -> InstanceNorm -> LeakyReLU block."""
        return nn.Sequential(
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def norm_lrelu_conv(self, feat_in, feat_out):
        """Build an InstanceNorm -> LeakyReLU -> 3x3x3 conv block."""
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def lrelu_conv(self, feat_in, feat_out):
        """Build a LeakyReLU -> 3x3x3 conv block."""
        return nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
        """Build norm -> LeakyReLU -> 2x nearest upsample -> conv -> norm -> LeakyReLU."""
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def forward(self, x, subj):
        """Run the network.

        Args:
            x: input volume of shape (batch, in_channels, D, H, W); D/H/W
                need NOT be divisible by 16 any more.
            subj: unused here; kept for interface compatibility with callers.

        Returns:
            Tensor of shape (batch, n_classes, D, H, W): deep-supervised
            prediction plus the ``context0``-weighted input residual.
        """
        # ---- Level 1 context pathway ----
        context_0 = x
        out = self.conv3d_c1_1(x)
        residual_1 = out
        out = self.lrelu(out)
        out = self.conv3d_c1_2(out)
        out = self.dropout3d(out)
        out = self.lrelu_conv_c1(out)
        # element-wise residual summation
        out += residual_1
        context_1 = self.lrelu(out)
        out = self.inorm3d_c1(out)
        out = self.lrelu(out)

        # ---- Level 2 context pathway ----
        out = self.conv3d_c2(out)
        residual_2 = out
        # NOTE: the same norm_lrelu_conv module is applied twice (weight
        # sharing), matching the original implementation — TODO confirm
        # this is intentional rather than a missing second module.
        out = self.norm_lrelu_conv_c2(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c2(out)
        out += residual_2
        out = self.inorm3d_c2(out)
        out = self.lrelu(out)
        context_2 = out

        # ---- Level 3 context pathway ----
        out = self.conv3d_c3(out)
        residual_3 = out
        out = self.norm_lrelu_conv_c3(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c3(out)
        out += residual_3
        out = self.inorm3d_c3(out)
        out = self.lrelu(out)
        context_3 = out

        # ---- Level 4 context pathway ----
        out = self.conv3d_c4(out)
        residual_4 = out
        out = self.norm_lrelu_conv_c4(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c4(out)
        out += residual_4
        out = self.inorm3d_c4(out)
        out = self.lrelu(out)
        context_4 = out

        # ---- Level 5 (bottleneck) ----
        out = self.conv3d_c5(out)
        residual_5 = out
        out = self.norm_lrelu_conv_c5(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c5(out)
        out += residual_5
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)
        out = self.conv3d_l0(out)
        out = self.inorm3d_l0(out)
        out = self.lrelu(out)

        # ---- Level 1 localization pathway ----
        # Align to the skip connection before concatenating: for odd
        # encoder sizes the 2x upsample overshoots by one voxel per axis.
        out = self._match_size(out, context_4)
        out = torch.cat([out, context_4], dim=1)
        out = self.conv_norm_lrelu_l1(out)
        out = self.conv3d_l1(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)

        # ---- Level 2 localization pathway ----
        out = self._match_size(out, context_3)
        out = torch.cat([out, context_3], dim=1)
        out = self.conv_norm_lrelu_l2(out)
        ds2 = out
        out = self.conv3d_l2(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)

        # ---- Level 3 localization pathway ----
        out = self._match_size(out, context_2)
        out = torch.cat([out, context_2], dim=1)
        out = self.conv_norm_lrelu_l3(out)
        ds3 = out
        out = self.conv3d_l3(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)

        # ---- Level 4 localization pathway ----
        out = self._match_size(out, context_1)
        out = torch.cat([out, context_1], dim=1)
        out = self.conv_norm_lrelu_l4(out)
        out_pred = self.conv3d_l4(out)

        # ---- deep supervision: upsample level-2/3 heads to full size ----
        ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
        ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv)
        ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
        # the fixed 2x upscale can again mismatch the odd target size
        ds1_ds2_sum_upscale = self._match_size(ds1_ds2_sum_upscale, ds3_1x1_conv)
        ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
        ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum)
        ds1_ds2_sum_upscale_ds3_sum_upscale = self._match_size(
            ds1_ds2_sum_upscale_ds3_sum_upscale, out_pred)

        out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale
        # weighted input residual: out + context0 * x (inpainting shortcut)
        seg_layer = out + self.context0 * context_0
        return seg_layer