Dilation Unet Sizes

I am trying to create a 3D UNet for medical image segmentation with dilated convolutions, to get the receptive field I need without making the model too heavy. I am having a lot of trouble getting the sizes along the encoder and decoder to match. My input is [bs, 4, 96, 96, 64] and my output is [bs, 5, 96, 96, 64], i.e. each voxel has 5 possible classes. As you can see, I have my down block and up block defined, and I max-pool twice in the encoder. The printed sizes and the error are shown after the code below:

The error occurs in the first up block: the interpolation produces a size of 12 while the corresponding skip connection has size 13 in the 3rd dimension. Can someone please help me make it symmetrical, or at least workable?

import torch

class UNet_down_block(torch.nn.Module):
    def __init__(self, input_channel, output_channel, down_size):
        super(UNet_down_block, self).__init__()
        self.conv1 = torch.nn.Conv3d(input_channel, output_channel, 3, padding=1,dilation=1)
        self.bn1 = torch.nn.BatchNorm3d(output_channel)
        self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1,dilation=2)
        self.bn2 = torch.nn.BatchNorm3d(output_channel)
        self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1,dilation=2)
        self.bn3 = torch.nn.BatchNorm3d(output_channel)
        self.max_pool = torch.nn.MaxPool3d(2, 2)
        self.relu = torch.nn.ELU()
        self.down_size = down_size

    def forward(self, x):
        if self.down_size:
            x = self.max_pool(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x

class UNet_up_block(torch.nn.Module):
    def __init__(self, prev_channel, input_channel, output_channel):
        super(UNet_up_block, self).__init__()
#         self.up_sampling = torch.nn.functional.interpolate(scale_factor=2, mode='trilinear')
        self.conv1 = torch.nn.Conv3d(prev_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(output_channel)
        self.conv2 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm3d(output_channel)
        self.conv3 = torch.nn.Conv3d(output_channel, output_channel, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm3d(output_channel)
        self.relu = torch.nn.ELU()

    def forward(self, prev_feature_map, x):

        x = torch.nn.functional.interpolate(x,scale_factor=2, mode='trilinear')

        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class UNet(torch.nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.down_block1 = UNet_down_block(4, 24, False)
        self.down_block2 = UNet_down_block(24, 72, True)
        self.down_block3 = UNet_down_block(72, 148, True)
        self.down_block4 = UNet_down_block(148, 224, False)
        self.max_pool = torch.nn.MaxPool3d(2, 2)



        self.mid_conv1 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(224)
        self.mid_conv2 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm3d(224)
        self.mid_conv3 = torch.nn.Conv3d(224, 224, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm3d(224)


        self.up_block1 = UNet_up_block(224, 224, 148)
        self.up_block2 = UNet_up_block(148, 148, 72)
        self.up_block3 = UNet_up_block(72, 72, 24)
        self.up_block4 = UNet_up_block(24, 24, 8)

        self.last_conv1 = torch.nn.Conv3d(8, 4, 3, padding=1)
        self.last_bn = torch.nn.BatchNorm3d(4)
        self.last_conv2 = torch.nn.Conv3d(4, 1, 1, padding=0)
        self.relu = torch.nn.ELU()
        self.last_conv3 = torch.nn.Conv3d(1, 1, 1, padding=0)

        self.conv1f=torch.nn.Conv2d(1, 5, 3,padding=1)
        self.conv2f=torch.nn.Conv2d(5, 5, 3,padding=1)
        self.conv3f=torch.nn.Conv2d(5, 5, 3,padding=1)

    def forward(self, x):
        print('input unet',x.size())
        self.x1 = self.down_block1(x)
        print("Block 1 shape:",self.x1.size())
        self.x2 = self.down_block2(self.x1)
        # workaround: crop when this intermediate size comes out as 49
        if self.x2.size()[2] == 49:
            self.x2 = self.x2[:, :, 1:, 1:, :]


        print("Block 2 shape:",self.x2.size())
        self.x3 = self.down_block3(self.x2)
        print("Block 3 shape:",self.x3.size())


        self.x4 = self.down_block4(self.x3)
        print("Block 4 shape:",self.x4.size())


        self.xmid=self.max_pool(self.x4)
        self.xmid = self.relu(self.bn1(self.mid_conv1(self.xmid)))
        self.xmid = self.relu(self.bn2(self.mid_conv2(self.xmid)))
        self.xmid = self.relu(self.bn3(self.mid_conv3(self.xmid)))
        print("Block Mid shape:",self.xmid.size())



        x = self.up_block1(self.x4, self.xmid)
#         print("BlockU 1 shape:",x.size())
        x = self.up_block2(self.x3, x)
        print("BlockU 2 shape:",x.size())

        x = self.up_block3(self.x2, x)
        print("BlockU 3 shape:",x.size())

        # workaround: crop when this intermediate size comes out as 98
        if self.x1.size()[2] == 98:
            self.x1 = self.x1[:, :, 1:-1, 1:-1, :]
#             print('chan98',self.x1.size())

        x = self.up_block4(self.x1, x)
        print("BlockU 4 shape:",x.size())


        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.last_conv2(x)  # of size [batch_size,1,h,w,depth] or [bs, modalities(1) ,96 ,96 , 64]

        x=x.view(batch_size,1,-1,64)
#         x=x.squeeze(1)  
#         print('input convf',x.size())
        conv=self.relu(self.conv1f(x))
        conv=self.relu(self.conv2f(conv))
        conv=self.conv3f(conv)


        try:
            conv=conv.view(batch_size,5,96,96,64)
        except:
            conv=conv.view(batch_size_val,5,96,96,64)
#         print('unet output',conv.size())

        return(conv)

Here is the output:

input unet torch.Size([1, 4, 96, 96, 64])
Block 1 shape: torch.Size([1, 24, 92, 92, 60])
Block 2 shape: torch.Size([1, 72, 42, 42, 26])
Block 3 shape: torch.Size([1, 148, 17, 17, 9])
Block 4 shape: torch.Size([1, 224, 13, 13, 5])
Block Mid shape: torch.Size([1, 224, 6, 6, 2]) 

Error:
--> 110         x = self.up_block1(self.x4, self.xmid)
    111 #         print("BlockU 1 shape:",x.size())
    112         x = self.up_block2(self.x3, x)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

<ipython-input-5-cbcdda025480> in forward(self, prev_feature_map, x)
     39         x = torch.nn.functional.interpolate(x,scale_factor=2, mode='trilinear')
     40 
---> 41         x = torch.cat((x, prev_feature_map), dim=1)
     42         x = self.relu(self.bn1(self.conv1(x)))
     43         x = self.relu(self.bn2(self.conv2(x)))

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 12 and 13 in dimension 2 at /opt/conda/conda-bld/pytorch_1535491974311/work/aten/src/TH/generic/THTensorMath.cpp:3616

It's a bit hard to debug your code, but I think the easiest fix would be to specify the target size for torch.nn.functional.interpolate instead of relying on scale_factor=2. You could pass the desired size into UNet_up_block and call interpolate with size=size; a sketch is below.
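For example, here is a minimal sketch of the up-block forward that takes the target size directly from the skip connection (align_corners=False is just one reasonable choice). For context, the encoder sizes come out odd (13, 5, ...) because each 3x3x3 conv with dilation=2 but padding=1 trims 2 voxels per dimension, so a fixed scale_factor=2 cannot always match them exactly:

    def forward(self, prev_feature_map, x):
        # upsample to exactly the skip connection's spatial size instead of a fixed factor,
        # so an odd size such as 13 no longer mismatches the doubled 12
        x = torch.nn.functional.interpolate(
            x, size=prev_feature_map.shape[2:], mode='trilinear', align_corners=False)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x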

Thanks! I will try this. The main issue I was having was due to ConvTranspose3d: https://github.com/pytorch/pytorch/issues/2119#issuecomment-430311611

It is reading my 5D inputs as 2D. Maybe you can offer a quick fix for that too. A transposed convolution would definitely be more accurate than interpolation for my model; roughly, this is the kind of up block I had in mind:
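(Just a sketch of the idea, not a fix for the linked issue; the class name, the kernel_size=2 / stride=2 choice, and passing output_size in forward are my own assumptions.)

import torch

class TransposeUpBlock(torch.nn.Module):
    # sketch: up block using ConvTranspose3d instead of interpolation
    def __init__(self, prev_channel, input_channel, output_channel):
        super(TransposeUpBlock, self).__init__()
        self.up = torch.nn.ConvTranspose3d(input_channel, input_channel, kernel_size=2, stride=2)
        self.conv1 = torch.nn.Conv3d(prev_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(output_channel)
        self.relu = torch.nn.ELU()

    def forward(self, prev_feature_map, x):
        # output_size may exceed 2x the input by up to stride - 1, so 6 -> 13 is reachable
        x = self.up(x, output_size=prev_feature_map.shape[2:])
        x = torch.cat((x, prev_feature_map), dim=1)
        return self.relu(self.bn1(self.conv1(x)))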