Dear senior programmers,
I am still at my first steps with pytorch and programming in general. I am mainly focusing on binary segmentation tasks. I have managed to train a V-Net structure network with my own data. However, the results are quite bad (dice=0.63). I would like to include some dilated convolutional layers into the network structure. Please, could anyone explain to me how to achieve that. My network is as follows.
import torch
from torch import nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
super(ConvBlock, self).__init__()
ops = []
for i in range(n_stages):
if i==0:
input_channel = n_filters_in
else:
input_channel = n_filters_out
ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class ResidualConvBlock(nn.Module):
def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
super(ResidualConvBlock, self).__init__()
ops = []
for i in range(n_stages):
if i == 0:
input_channel = n_filters_in
else:
input_channel = n_filters_out
ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
if i != n_stages-1:
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = (self.conv(x) + x)
x = self.relu(x)
return x
class DownsamplingConvBlock(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(DownsamplingConvBlock, self).__init__()
ops = []
if normalization != 'none':
ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
else:
assert False
else:
ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class UpsamplingDeconvBlock(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(UpsamplingDeconvBlock, self).__init__()
ops = []
if normalization != 'none':
ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
else:
assert False
else:
ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class Upsampling(nn.Module):
def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
super(Upsampling, self).__init__()
ops = []
ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False))
ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
if normalization == 'batchnorm':
ops.append(nn.BatchNorm3d(n_filters_out))
elif normalization == 'groupnorm':
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
elif normalization == 'instancenorm':
ops.append(nn.InstanceNorm3d(n_filters_out))
elif normalization != 'none':
assert False
ops.append(nn.ReLU(inplace=True))
self.conv = nn.Sequential(*ops)
def forward(self, x):
x = self.conv(x)
return x
class VNet(nn.Module):
def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
super(VNet, self).__init__()
self.has_dropout = has_dropout
self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)
self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization)
self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
self.dropout = nn.Dropout3d(p=0.5, inplace=False)
self.sigmoid = nn.Sigmoid()
# self.__init_weight()
def encoder(self, input):
x1 = self.block_one(input)
x1_dw = self.block_one_dw(x1)
x2 = self.block_two(x1_dw)
x2_dw = self.block_two_dw(x2)
x3 = self.block_three(x2_dw)
x3_dw = self.block_three_dw(x3)
x4 = self.block_four(x3_dw)
x4_dw = self.block_four_dw(x4)
x5 = self.block_five(x4_dw)
# x5 = F.dropout3d(x5, p=0.5, training=True)
if self.has_dropout:
x5 = self.dropout(x5)
res = [x1, x2, x3, x4, x5]
return res
def decoder(self, features):
x1 = features[0]
x2 = features[1]
x3 = features[2]
x4 = features[3]
x5 = features[4]
x5_up = self.block_five_up(x5)
x5_up = x5_up + x4
x6 = self.block_six(x5_up)
x6_up = self.block_six_up(x6)
x6_up = x6_up + x3
x7 = self.block_seven(x6_up)
x7_up = self.block_seven_up(x7)
x7_up = x7_up + x2
x8 = self.block_eight(x7_up)
x8_up = self.block_eight_up(x8)
x8_up = x8_up + x1
x9 = self.block_nine(x8_up)
# x9 = F.dropout3d(x9, p=0.5, training=True)
if self.has_dropout:
x9 = self.dropout(x9)
out = self.out_conv(x9)
out = self.sigmoid(out)
return out
def forward(self, input, turnoff_drop=False):
if turnoff_drop:
has_dropout = self.has_dropout
self.has_dropout = False
features = self.encoder(input)
out = self.decoder(features)
if turnoff_drop:
self.has_dropout = has_dropout
return out
Any help and/or suggestions would be highly appreciated.