class _EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, dropout=False):
super(_EncoderBlock, self).__init__()
self.e_mxpl1 = nn.MaxPool3d(kernel_size=2, stride=2)
self.e_conv1 = nn.Conv3d(in_channels, out_channels,kernel_size=3),#Conv3dSep(in_channels, out_channels),
self.e_bn1 = nn.BatchNorm3d(out_channels),
self.e_relu1 = nn.ReLU(inplace=True),
self.e_conv2 = nn.Conv3d(out_channels, out_channels,kernel_size=3),#Conv3dSep(out_channels, out_channels),
self.e_bn2 = nn.BatchNorm3d(out_channels),
self.e_relu2 =nn.ReLU(inplace=True),
self.e_drp = nn.Dropout()
def forward(self, x):
x = self.e_mxpl1(x)
x = self.e_conv1(x)
x = self.e_bn1(x)
x = self.e_relu1(x)
x = self.e_conv2(x)
x = self.e_bn2(x)
x = self.e_relu2(x)
if dropout:
x = self.e_drp(x)
return x
class _DecoderBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super(_DecoderBlock, self).__init__()
self.d_conv1 = nn.Conv3d(in_channels, middle_channels,kernel_size=3),#Conv3dSep(in_channels, middle_channels),
self.d_bn1 = nn.BatchNorm3d(middle_channels),
self.d_relu1 = nn.ReLU(inplace=True),
self.d_conv2 = nn.Conv3d(middle_channels, middle_channels,kernel_size=3),#Conv3dSep(middle_channels, middle_channels),
self.d_bn2 = nn.BatchNorm3d(middle_channels),
self.d_relu2 = nn.ReLU(inplace=True),
self.d_convT = nn.ConvTranspose3d(middle_channels, out_channels, kernel_size=2, stride=2)
def forward(self, x):
x = self.d_conv1(x)
x = self.d_bn1(x)
x = self.d_relu1(x)
x = self.d_conv2(x)
x = self.d_bn2(x)
x = self.d_relu2(x)
x = self.d_convT(x)
return x
class WNet3D(nn.Module):
    """3D W-Net: two U-Nets chained back to back.

    The first U-Net (``module_1`` .. ``module_9`` plus ``mid``) maps the
    single-channel input to ``num_classes`` channels and applies ``self.sm``.
    That result, resized to the input's spatial size, is fed to the second
    U-Net (``module_10`` .. ``module_18`` plus ``final``).

    ``forward`` returns ``(middle, final)``. Both tensors are trilinearly
    resized to the spatial size (dims 2+) of the input.

    Every 3x3x3 convolution is unpadded and removes 2 voxels per spatial
    axis, so small volumes collapse to zero size in the deeper stages.
    Skip connections are trilinearly resized to the decoder's size rather
    than cropped.

    Args:
        num_classes: number of channels of the intermediate output and of
            the final output.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # ---- first U-Net ----
        self.module_1 = self._double_conv(1, 64)
        self.module_2 = _EncoderBlock(64, 128)
        self.module_3 = _EncoderBlock(128, 256)
        self.module_4 = _EncoderBlock(256, 512, dropout=True)
        self.mxpl_enc = nn.MaxPool3d(kernel_size=2, stride=2)
        self.module_5 = _DecoderBlock(512, 1024, 512)
        # Decoder input channels are doubled: upsampled features + skip.
        self.module_6 = _DecoderBlock(1024, 512, 256)
        self.module_7 = _DecoderBlock(512, 256, 128)
        self.module_8 = _DecoderBlock(256, 128, 64)
        self.module_9 = self._double_conv(128, 64)
        self.mid = nn.Conv3d(64, num_classes, kernel_size=1)
        # NOTE(review): dim=4 is the last spatial axis of the 5D
        # (N, C, D, H, W) activations, not the channel/class axis. A
        # per-voxel class softmax would use dim=1, but that is constantly
        # 1.0 when num_classes == 1. Kept as-is to preserve outputs;
        # confirm intent.
        self.sm = nn.Softmax(dim=4)

        # ---- second U-Net ----
        # Consumes the num_classes-channel output of the first U-Net.
        self.module_10 = _EncoderBlock(num_classes, 64)
        self.module_11 = _EncoderBlock(64, 128)
        self.module_12 = _EncoderBlock(128, 256)
        self.module_13 = _EncoderBlock(256, 512, dropout=True)
        self.mxpl_dec = nn.MaxPool3d(kernel_size=2, stride=2)
        self.module_14 = _DecoderBlock(512, 1024, 512)
        self.module_15 = _DecoderBlock(1024, 512, 256)
        self.module_16 = _DecoderBlock(512, 256, 128)
        self.module_17 = _DecoderBlock(256, 128, 64)
        self.module_18 = self._double_conv(128, 64)
        self.final = nn.Conv3d(64, num_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build two unpadded (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) units."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _resize(t, like):
        """Trilinearly resize 5D ``t`` to the spatial size (dims 2+) of ``like``."""
        # 'bilinear' accepts only 4D input; volumetric (5D) tensors need
        # 'trilinear'. align_corners=False matches F.upsample's old default.
        return F.interpolate(
            t, size=tuple(like.shape[2:]), mode='trilinear', align_corners=False
        )

    @classmethod
    def _cat_skip(cls, up, skip):
        """Concatenate ``up`` with ``skip`` (resized to ``up``) along channels."""
        return torch.cat([up, cls._resize(skip, up)], 1)

    def forward(self, x):
        """Run both U-Nets; return ``(middle, final)`` at ``x``'s spatial size."""
        # ---- first U-Net ----
        enc1 = self.module_1(x)
        enc2 = self.module_2(enc1)
        enc3 = self.module_3(enc2)
        enc4 = self.mxpl_enc(self.module_4(enc3))
        center = self.module_5(enc4)
        dec4 = self.module_6(self._cat_skip(center, enc4))
        dec3 = self.module_7(self._cat_skip(dec4, enc3))
        dec2 = self.module_8(self._cat_skip(dec3, enc2))
        dec1 = self.module_9(self._cat_skip(dec2, enc1))
        middle = self._resize(self.sm(self.mid(dec1)), x)

        # ---- second U-Net ----
        enc5 = self.module_10(middle)
        enc6 = self.module_11(enc5)
        enc7 = self.module_12(enc6)
        enc8 = self.mxpl_dec(self.module_13(enc7))
        center = self.module_14(enc8)
        dec8 = self.module_15(self._cat_skip(center, enc8))
        dec7 = self.module_16(self._cat_skip(dec8, enc7))
        dec6 = self.module_17(self._cat_skip(dec7, enc6))
        dec5 = self.module_18(self._cat_skip(dec6, enc5))
        final = self._resize(self.final(dec5), x)
        return middle, final