Layer size mismatch after modifying a 2D image registration network for 3D images

Here is the code, originally written for 2D image registration (a recursive cascade network), which I modified for 3D images. I am not able to match the layer sizes, so the problem occurs in the concatenation operations.

```python
import torch
import torch.nn as nn
from torch.nn import ReLU, LeakyReLU
import torch.nn.functional as F

def conv(dim=3):
    return nn.Conv3d

def trans_conv(dim=3):
    return nn.ConvTranspose3d

def convolve(in_channels, out_channels, kernel_size, stride, dim=3):
    return conv(dim=dim)(in_channels, out_channels, kernel_size, stride=stride, padding=1)

def convolveReLU(in_channels, out_channels, kernel_size, stride, dim=3):
    # note: the activation must be instantiated (ReLU()), not passed as a class
    return nn.Sequential(ReLU(), convolve(in_channels, out_channels, kernel_size, stride, dim=dim))

def convolveLeakyReLU(in_channels, out_channels, kernel_size, stride, dim=3):
    return nn.Sequential(LeakyReLU(0.1), convolve(in_channels, out_channels, kernel_size, stride, dim=dim))

def upconvolve(in_channels, out_channels, kernel_size, stride, dim=3):
    return trans_conv(dim=dim)(in_channels, out_channels, kernel_size, stride, padding=1)

def upconvolveReLU(in_channels, out_channels, kernel_size, stride, dim=3):
    return nn.Sequential(ReLU(), upconvolve(in_channels, out_channels, kernel_size, stride, dim=dim))

def upconvolveLeakyReLU(in_channels, out_channels, kernel_size, stride, dim=3):
    return nn.Sequential(LeakyReLU(0.1), upconvolve(in_channels, out_channels, kernel_size, stride, dim=dim))

class VTN(nn.Module):
    def __init__(self, dim=3, flow_multiplier=1., channels=16):
        super(VTN, self).__init__()
        self.flow_multiplier = flow_multiplier
        self.channels = channels
        self.dim = dim

        # Encoder: each stride-2 convolution roughly halves D, H and W
        self.conv1 = convolveLeakyReLU(2, channels, 3, 2, dim=dim)
        self.conv2 = convolveLeakyReLU(channels, 2 * channels, 3, 2, dim=dim)
        self.conv3 = convolveLeakyReLU(2 * channels, 4 * channels, 3, 2, dim=dim)
        self.conv3_1 = convolveLeakyReLU(4 * channels, 4 * channels, 3, 1, dim=dim)
        self.conv4 = convolveLeakyReLU(4 * channels, 8 * channels, 3, 2, dim=dim)
        self.conv4_1 = convolveLeakyReLU(8 * channels, 8 * channels, 3, 1, dim=dim)

        # Decoder: each transposed convolution (kernel 4, stride 2, padding 1) doubles D, H and W
        self.pred4 = convolve(8 * channels, dim, 3, 1, dim=dim)
        self.upsamp4to3 = upconvolve(dim, dim, 4, 2, dim=dim)
        self.deconv3 = upconvolveLeakyReLU(8 * channels, 4 * channels, 4, 2, dim=dim)

        self.pred3 = convolve(8 * channels + dim, dim, 3, 1, dim=dim)
        self.upsamp3to2 = upconvolve(dim, dim, 4, 2, dim=dim)
        self.deconv2 = upconvolveLeakyReLU(8 * channels + dim, 2 * channels, 4, 2, dim=dim)

        self.pred2 = convolve(4 * channels + dim, dim, 3, 1, dim=dim)
        self.upsamp2to1 = upconvolve(dim, dim, 4, 2, dim=dim)
        self.deconv1 = upconvolveLeakyReLU(4 * channels + dim, channels, 4, 2, dim=dim)

        self.pred0 = upconvolve(2 * channels + dim, dim, 4, 2, dim=dim)

    def forward(self, fixed, moving):
        # Concatenate fixed and moving images along the channel dimension
        concat_image = torch.cat((fixed, moving), dim=1)  # [B, 2, D, H, W]

        # Note: stride-2 convs compute ceil(size / 2), so odd sizes stop halving exactly
        x1 = self.conv1(concat_image)  # [B, C, D/2, H/2, W/2]
        x2 = self.conv2(x1)            # [B, 2C, D/4, H/4, W/4]
        x3 = self.conv3(x2)            # [B, 4C, D/8, H/8, W/8]
        x3_1 = self.conv3_1(x3)        # [B, 4C, D/8, H/8, W/8]
        x4 = self.conv4(x3_1)          # [B, 8C, D/16, H/16, W/16]
        x4_1 = self.conv4_1(x4)        # [B, 8C, D/16, H/16, W/16]

        pred4 = self.pred4(x4_1)
        upsamp4to3 = self.upsamp4to3(pred4)
        deconv3 = self.deconv3(x4_1)
        concat3 = torch.cat([x3_1, deconv3, upsamp4to3], dim=1)

        pred3 = self.pred3(concat3)
        upsamp3to2 = self.upsamp3to2(pred3)
        deconv2 = self.deconv2(concat3)
        concat2 = torch.cat([x2, deconv2, upsamp3to2], dim=1)

        pred2 = self.pred2(concat2)
        upsamp2to1 = self.upsamp2to1(pred2)
        deconv1 = self.deconv1(concat2)
        concat1 = torch.cat([x1, deconv1, upsamp2to1], dim=1)

        pred0 = self.pred0(concat1)

        return pred0 * 20 * self.flow_multiplier

class VTNAffineStem(nn.Module):

    def __init__(self, dim=3, channels=16, flow_multiplier=1., im_size=128):
        super(VTNAffineStem, self).__init__()
        self.flow_multiplier = flow_multiplier
        self.channels = channels
        self.dim = dim

        # Network architecture
        # The first convolution's input is the concatenated image
        self.conv1 = convolveLeakyReLU(2, channels, 3, 2, dim=self.dim)
        self.conv2 = convolveLeakyReLU(channels, 2 * channels, 3, 2, dim=dim)
        self.conv3 = convolveLeakyReLU(2 * channels, 4 * channels, 3, 2, dim=dim)
        self.conv3_1 = convolveLeakyReLU(4 * channels, 4 * channels, 3, 1, dim=dim)
        self.conv4 = convolveLeakyReLU(4 * channels, 8 * channels, 3, 2, dim=dim)
        self.conv4_1 = convolveLeakyReLU(8 * channels, 8 * channels, 3, 1, dim=dim)

        self.last_conv_size = im_size // (self.channels * 4)
        #print(self.last_conv_size)

        self.fc_loc = nn.Sequential(
            nn.Linear(128 * self.last_conv_size**dim, 2048),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(2048, 1024),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(1024, 4),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4, 6 * (dim - 1))
        )
        # Initialize the affine head to the identity transform
        self.fc_loc[-1].weight.data.zero_()
        self.fc_loc[-1].bias.data.copy_(torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, fixed, moving):

        #moving = torch.reshape(moving, fixed.shape)
        concat_image = torch.cat((fixed, moving), dim=1)  # [B, 2, 128, 128, 128] for a 128^3 input
        x1 = self.conv1(concat_image)  # [B, 16, 64, 64, 64]
        x2 = self.conv2(x1)            # [B, 32, 32, 32, 32]
        x3 = self.conv3(x2)            # [B, 64, 16, 16, 16]
        x3_1 = self.conv3_1(x3)        # [B, 64, 16, 16, 16]
        x4 = self.conv4(x3_1)          # [B, 128, 8, 8, 8]
        x4_1 = self.conv4_1(x4)        # [B, 128, 8, 8, 8]

        # Affine transformation: regress a 3x4 matrix and build a sampling grid
        xs = x4_1.view(-1, 128 * self.last_conv_size ** self.dim)
        theta = self.fc_loc(xs).view(-1, 3, 4)
        flow = F.affine_grid(theta, moving.size(), align_corners=False)  # [B, D, H, W, 3]
        flow = flow.permute(0, 4, 1, 2, 3)  # [B, 3, D, H, W]

        return flow

if __name__ == "__main__":
    vtn_model = VTN(dim=3)
    vtn_affine_model = VTNAffineStem(dim=3, im_size=128)
    x = torch.randn(1, 1, 90, 128, 128)
    y1 = vtn_model(x, x)
    y2 = vtn_affine_model(x, x)
    assert y1.size() == y2.size()
```

What kind of error are you seeing?
I assume a shape mismatch is raised in this modified model, while the 2D version works?

PS: you can post code snippets by wrapping them in three backticks ``` :wink:

I am getting the following error at the line `concat2 = torch.cat([x2, deconv2, upsamp3to2], dim=1)` in `VTN.forward`:

```
torch.cat(): Sizes of tensors must match except in dimension 1. Got 23 and 24 in dimension 2 (The offending index is 1)
```

I guess it is because the default input shape should be cubic, i.e. (image_size, image_size, image_size), e.g. 128, 128, 128. I am working with 3D images where the number of patches (slices) of the fixed image and the moving image is not the same, so I am trying to make this model take inputs of shape (1, 1, any value, 128, 128).
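Tracing the depth dimension of my test input (90) through the layers seems to confirm this. A quick check using the standard Conv3d / ConvTranspose3d output-size formulas (`conv_out` and `deconv_out` are just illustrative helpers, not part of the model):

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    # output size of a Conv3d with these hyperparameters
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=4, stride=2, padding=1):
    # output size of a ConvTranspose3d with these hyperparameters
    return (size - 1) * stride - 2 * padding + kernel

d = 90
d1 = conv_out(d)   # 45 (x1)
d2 = conv_out(d1)  # 23 (x2)
d3 = conv_out(d2)  # 12 (x3, x3_1)
d4 = conv_out(d3)  # 6  (x4, x4_1)

print(deconv_out(d4))  # 12 -> matches d3, so concat3 works
print(deconv_out(d3))  # 24 -> does not match d2 = 23, hence the error
```

So the first skip connection happens to line up (12 vs. 12), but the second one fails with 23 vs. 24, exactly the sizes in the error message.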

Since you are working with a variable input shape, you would have to make sure the activation shapes match, e.g. by padding/slicing them, or by using a model setup that creates defined output shapes.
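For the padding approach, here is a minimal sketch: since the encoder downsamples by 2 four times, padding every spatial dimension up to a multiple of 16 makes all the skip connections line up (`pad_to_multiple` is just an illustrative helper, not an existing utility):

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=16):
    # Zero-pad the D, H, W dims of a [B, C, D, H, W] tensor up to a multiple
    d, h, w = x.shape[2:]
    pad_d = (multiple - d % multiple) % multiple
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad expects pads for the last dim first: (W_left, W_right, H_left, H_right, D_left, D_right)
    return F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))

x = torch.randn(1, 1, 90, 128, 128)
x_padded = pad_to_multiple(x)
print(x_padded.shape)  # torch.Size([1, 1, 96, 128, 128])
```

With the depth padded from 90 to 96, each stride-2 convolution halves it cleanly (96 → 48 → 24 → 12 → 6) and each transposed convolution doubles it back, so the `torch.cat` calls all see matching sizes. Alternatively, you could slice the upsampled activations down to the skip connection's shape inside `forward`, or resample all volumes to a fixed cubic size beforehand.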