Here is the code, which was originally written for 2D image registration,(recursive cascade network) I modified for 3D image. But not being able to match the size of the layer, problem is occurring in concatenation operations.
import torch
import torch.nn as nn
from torch.nn import ReLU, LeakyReLU
import torch.nn.functional as F
def conv(dim = 3):
return nn.Conv3d
def trans_conv(dim = 3):
return nn.ConvTranspose3d
def convolve(in_channels, out_channels, kernel_size, stride, dim=3):
return conv(dim=dim)(in_channels, out_channels, kernel_size, stride=stride, padding=1)
def convolveReLU(in_channels, out_channels, kernel_size, stride, dim=3):
return nn.Sequential(ReLU, convolve(in_channels, out_channels, kernel_size, stride, dim=dim))
def convolveLeakyReLU(in_channels, out_channels, kernel_size, stride, dim=3):
return nn.Sequential(LeakyReLU(0.1), convolve(in_channels, out_channels, kernel_size, stride, dim=dim))
def upconvolve(in_channels, out_channels, kernel_size, stride, dim=3):
return trans_conv(dim=dim)(in_channels, out_channels, kernel_size, stride, padding=1)
def upconvolveReLU(in_channels, out_channels, kernel_size, stride, dim=3):
return nn.Sequential(ReLU, upconvolve(in_channels, out_channels, kernel_size, stride, dim=dim))
def upconvolveLeakyReLU(in_channels, out_channels, kernel_size, stride, dim=3):
return nn.Sequential(LeakyReLU(0.1), upconvolve(in_channels, out_channels, kernel_size, stride, dim=dim))
class VTN(nn.Module):
def init(self, dim=3, flow_multiplier=1., channels=16):
super(VTN, self).init()
self.flow_multiplier = flow_multiplier
self.channels = channels
self.dim = dim
self.conv1 = convolveLeakyReLU(2, channels, 3, 2, dim=dim)
self.conv2 = convolveLeakyReLU(channels, 2 * channels, 3, 2, dim=dim)
self.conv3 = convolveLeakyReLU(2 * channels, 4 * channels, 3, 2, dim=dim)
self.conv3_1 = convolveLeakyReLU(4 * channels, 4 * channels, 3, 1, dim=dim)
self.conv4 = convolveLeakyReLU(4 * channels, 8 * channels, 3, 2, dim=dim)
self.conv4_1 = convolveLeakyReLU(8 * channels, 8 * channels, 3, 1, dim=dim)
self.pred4 = convolve(8 * channels, dim, 3, 1, dim=dim)
self.upsamp4to3 = upconvolve(dim, dim, 4, 2, dim=dim)
self.deconv3 = upconvolveLeakyReLU(8 * channels, 4 * channels, 4, 2, dim=dim)
self.pred3 = convolve(8 * channels+dim, dim, 3, 1, dim=dim)
self.upsamp3to2 = upconvolve(dim, dim, 4, 2, dim=dim)
self.deconv2 = upconvolveLeakyReLU(8 * channels+dim, 2 * channels, 4, 2, dim=dim)
self.pred2 = convolve(4 * channels + dim, dim, 3, 1, dim=dim)
self.upsamp2to1 = upconvolve(dim, dim, 4, 2, dim=dim)
self.deconv1 = upconvolveLeakyReLU(4 * channels + dim, channels, 4, 2, dim=dim)
self.pred0 = upconvolve(2 * channels + dim, dim, 4, 2, dim=dim)
def forward(self, fixed, moving):
concat_image = torch.cat((fixed, moving), dim=1)
x1 = self.conv1(concat_image)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
x3_1 = self.conv3_1(x3)
x4 = self.conv4(x3_1)
x4_1 = self.conv4_1(x4)
pred4 = self.pred4(x4_1)
upsamp4to3 = self.upsamp4to3(pred4)
deconv3 = self.deconv3(x4_1)
concat3 = torch.cat([x3_1, deconv3, upsamp4to3], dim=1)
pred3 = self.pred3(concat3)
upsamp3to2 = self.upsamp3to2(pred3)
deconv2 = self.deconv2(concat3)
concat2 = torch.cat([x2, deconv2, upsamp3to2], dim=1)
pred2 = self.pred2(concat2)
upsamp2to1 = self.upsamp2to1(pred2)
deconv1 = self.deconv1(concat2)
concat1 = torch.cat([x1, deconv1, upsamp2to1], dim=1)
pred0 = self.pred0(concat1)
return pred0 * 20 * self.flow_multiplier
class VTNAffineStem(nn.Module):
def __init__(self, dim=3, channels=16, flow_multiplier=1., im_size=128):
super(VTNAffineStem, self).__init__()
self.flow_multiplier = flow_multiplier
self.channels = channels
self.dim = dim
# Network architecture
# The first convolution's input is the concatenated image
self.conv1 = convolveLeakyReLU(2, channels, 3, 2, dim=self.dim)
self.conv2 = convolveLeakyReLU(channels, 2 * channels, 3, 2, dim=dim)
self.conv3 = convolveLeakyReLU(2 * channels, 4 * channels, 3, 2, dim=dim)
self.conv3_1 = convolveLeakyReLU(4 * channels, 4 * channels, 3, 1, dim=dim)
self.conv4 = convolveLeakyReLU(4 * channels, 8 * channels, 3, 2, dim=dim)
self.conv4_1 = convolveLeakyReLU(8 * channels, 8 * channels, 3, 1, dim=dim)
self.last_conv_size = im_size // (self.channels * 4)
#print(self.last_conv_size)
self.fc_loc = nn.Sequential(
nn.Linear(128 * self.last_conv_size**dim, 2048),
nn.ReLU(True),
nn.Dropout(0.5),
nn.Linear(2048, 1024),
nn.ReLU(True),
nn.Dropout(0.5),
nn.Linear(1024, 4),
nn.ReLU(True),
nn.Dropout(0.5),
nn.Linear(4, 6*(dim - 1))
)
self.fc_loc[-1].weight.data.zero_()
self.fc_loc[-1].bias.data.copy_(torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], dtype=torch.float))
def forward(self, fixed, moving):
#moving = torch.reshape(moving, fixed.shape)
concat_image = torch.cat((fixed, moving), dim=1) # 2 x 512 x 512
x1 = self.conv1(concat_image) # 16 x 256 x 256
x2 = self.conv2(x1) # 32 x 128 x 128
x3 = self.conv3(x2) # 1 x 64 x 64 x 64
x3_1 = self.conv3_1(x3) # 64 x 64 x 64
x4 = self.conv4(x3_1) # 128 x 32 x 32
x4_1 = self.conv4_1(x4) # 128 x 32 x 32
# Affine transformation
xs = x4_1.view(-1, 128 * self.last_conv_size ** self.dim)
theta = self.fc_loc(xs).view(-1, 3, 4)
flow = F.affine_grid(theta, moving.size(), align_corners=False) # batch x 512 x 512 x 2
flow = flow.permute(0, 4, 1, 2, 3)
return flow
if name == “main”:
vtn_model = VTN(dim=3)
vtn_affine_model = VTNAffineStem(dim=3, im_size=128)
x = torch.randn(1, 1, 90, 128, 128)
y1 = vtn_model(x, x)
y2 = vtn_affine_model(x, x)
assert y1.size() == y2.size()