Hello. I have built a U-Net model from Conv2d, batch normalization, ReLU, and similar layers.
While training it with a Dice loss, the accuracy and loss values look very strange and the network does not learn at all.
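For context, the Dice loss I train with is essentially the standard soft Dice on sigmoid outputs; the sketch below only shows the general form (the class name and details here are illustrative, not my exact code):

import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    """Generic soft Dice loss for binary segmentation (illustrative sketch)."""
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        # logits, targets: (N, 1, H, W); targets are 0/1 masks of the same size
        probs = torch.sigmoid(logits)
        probs = probs.view(probs.size(0), -1)
        targets = targets.view(targets.size(0), -1)
        intersection = (probs * targets).sum(dim=1)
        union = probs.sum(dim=1) + targets.sum(dim=1)
        dice = (2.0 * intersection + self.eps) / (union + self.eps)
        return 1.0 - dice.mean()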
I checked the model with the torchsummary library, and it turns out that several extra BatchNorm2d and ReLU
layers appear in the summary, but I don't know why that happens. Can you explain why those layers are added?
Thanks in advance.
Below are my code and the summary of my model.
import math
from collections import OrderedDict

import torch
import torch.nn as nn


class U_net(nn.Module):
    def __init__(self, in_channel, n_classes):
        super(U_net, self).__init__()
        # input image channels
        self.in_channel = in_channel
        # contractive path conv kernel size: (3x3)
        self.c_kernel_size = 3
        # expansive path up-conv kernel size: (2x2)
        self.up_kernel_size = 2
        self.final_kernel_size = 1
        self.max_pool_kernel_size = 2
        self.max_pool_stride = 2
        self.deconv_stride = 2
        # activation function
        self.c_activation_f = nn.ReLU(inplace=True)
        # channels
        self.n_out_channels1 = 64
        self.n_out_channels2 = 128
        self.n_out_channels3 = 256
        self.n_out_channels4 = 512
        self.n_out_channels5 = 1024
        self.final_out_channels = n_classes
        self.n_skip_channels1 = 128
        self.n_skip_channels2 = 256
        self.n_skip_channels3 = 512
        self.n_skip_channels4 = 1024
        self.c_max_pooling = nn.MaxPool2d(self.max_pool_kernel_size, stride=self.max_pool_stride)
        # Batch normalization
        self.c_bn11 = nn.BatchNorm2d(self.n_out_channels1)
        self.c_bn12 = nn.BatchNorm2d(self.n_out_channels1)
        self.c_bn21 = nn.BatchNorm2d(self.n_out_channels2)
        self.c_bn22 = nn.BatchNorm2d(self.n_out_channels2)
        self.c_bn31 = nn.BatchNorm2d(self.n_out_channels3)
        self.c_bn32 = nn.BatchNorm2d(self.n_out_channels3)
        self.c_bn41 = nn.BatchNorm2d(self.n_out_channels4)
        self.c_bn42 = nn.BatchNorm2d(self.n_out_channels4)
        self.c_bn51 = nn.BatchNorm2d(self.n_out_channels5)
        self.c_bn52 = nn.BatchNorm2d(self.n_out_channels5)
        ### contractive path layers ###
        # 1st layer
        self.cont_layer1 = nn.Sequential(OrderedDict([
            ('c_conv11', nn.Conv2d(in_channels=self.in_channel, out_channels=self.n_out_channels1, kernel_size=self.c_kernel_size)),
            ('c_bn_11', self.c_bn11),
            ('c_act_f11', self.c_activation_f),
            ('c_conv12', nn.Conv2d(in_channels=self.n_out_channels1, out_channels=self.n_out_channels1, kernel_size=self.c_kernel_size)),
            ('c_bn_12', self.c_bn12),
            ('c_act_f12', self.c_activation_f)
        ]))
        # 2nd layer
        self.cont_layer2 = nn.Sequential(OrderedDict([
            ('c_conv21', nn.Conv2d(in_channels=self.n_out_channels1, out_channels=self.n_out_channels2, kernel_size=self.c_kernel_size)),
            ('c_bn_21', self.c_bn21),
            ('c_act_f21', self.c_activation_f),
            ('c_conv22', nn.Conv2d(in_channels=self.n_out_channels2, out_channels=self.n_out_channels2, kernel_size=self.c_kernel_size)),
            ('c_bn_22', self.c_bn22),
            ('c_act_f22', self.c_activation_f)
        ]))
        # 3rd layer
        self.cont_layer3 = nn.Sequential(OrderedDict([
            ('c_conv31', nn.Conv2d(in_channels=self.n_out_channels2, out_channels=self.n_out_channels3, kernel_size=self.c_kernel_size)),
            ('c_bn_31', self.c_bn31),
            ('c_act_f31', self.c_activation_f),
            ('c_bn_32', self.c_bn32),
            ('c_conv32', nn.Conv2d(in_channels=self.n_out_channels3, out_channels=self.n_out_channels3, kernel_size=self.c_kernel_size)),
            ('c_act_f32', self.c_activation_f)
        ]))
        # 4th layer
        self.cont_layer4 = nn.Sequential(OrderedDict([
            ('c_conv41', nn.Conv2d(in_channels=self.n_out_channels3, out_channels=self.n_out_channels4, kernel_size=self.c_kernel_size)),
            ('c_bn_41', self.c_bn41),
            ('c_act_f41', self.c_activation_f),
            ('c_bn_42', self.c_bn42),
            ('c_conv42', nn.Conv2d(in_channels=self.n_out_channels4, out_channels=self.n_out_channels4, kernel_size=self.c_kernel_size)),
            ('c_act_f42', self.c_activation_f),
        ]))
        # 5th layer
        self.cont_layer5 = nn.Sequential(OrderedDict([
            ('c_conv51', nn.Conv2d(in_channels=self.n_out_channels4, out_channels=self.n_out_channels5, kernel_size=self.c_kernel_size)),
            ('c_bn_51', self.c_bn51),
            ('c_act_f51', self.c_activation_f),
            ('c_conv52', nn.Conv2d(in_channels=self.n_out_channels5, out_channels=self.n_out_channels5, kernel_size=self.c_kernel_size)),
            ('c_bn_52', self.c_bn52),
            ('c_act_f52', self.c_activation_f),
        ]))
        ### expansive path layers ###
        # 5th layer
        self.exp_layer5 = nn.ConvTranspose2d(in_channels=self.n_out_channels5, out_channels=self.n_out_channels4,
                                             kernel_size=self.up_kernel_size, stride=self.deconv_stride)
        # 4th layer
        self.exp_layer4 = nn.Sequential(OrderedDict([
            ('e_conv41', nn.Conv2d(in_channels=self.n_skip_channels4, out_channels=self.n_out_channels4, kernel_size=self.c_kernel_size)),
            ('e_act_f41', self.c_activation_f),
            ('e_conv42', nn.Conv2d(in_channels=self.n_out_channels4, out_channels=self.n_out_channels4, kernel_size=self.c_kernel_size)),
            ('e_act_f42', self.c_activation_f),
            ('e_up_conv4', nn.ConvTranspose2d(in_channels=self.n_out_channels4, out_channels=self.n_out_channels3,
                                              kernel_size=self.up_kernel_size, stride=self.deconv_stride))
        ]))
        # 3rd layer
        self.exp_layer3 = nn.Sequential(OrderedDict([
            ('e_conv31', nn.Conv2d(in_channels=self.n_skip_channels3, out_channels=self.n_out_channels3, kernel_size=self.c_kernel_size)),
            ('e_act_f31', self.c_activation_f),
            ('e_conv32', nn.Conv2d(in_channels=self.n_out_channels3, out_channels=self.n_out_channels3, kernel_size=self.c_kernel_size)),
            ('e_act_f32', self.c_activation_f),
            ('e_up_conv3', nn.ConvTranspose2d(in_channels=self.n_out_channels3, out_channels=self.n_out_channels2,
                                              kernel_size=self.up_kernel_size, stride=self.deconv_stride))
        ]))
        # 2nd layer
        self.exp_layer2 = nn.Sequential(OrderedDict([
            ('e_conv21', nn.Conv2d(in_channels=self.n_skip_channels2, out_channels=self.n_out_channels2, kernel_size=self.c_kernel_size)),
            ('e_act_f21', self.c_activation_f),
            ('e_conv22', nn.Conv2d(in_channels=self.n_out_channels2, out_channels=self.n_out_channels2, kernel_size=self.c_kernel_size)),
            ('e_act_f22', self.c_activation_f),
            ('e_up_conv2', nn.ConvTranspose2d(in_channels=self.n_out_channels2, out_channels=self.n_out_channels1,
                                              kernel_size=self.up_kernel_size, stride=self.deconv_stride))
        ]))
        # 1st layer
        self.exp_layer1 = nn.Sequential(OrderedDict([
            ('e_conv11', nn.Conv2d(in_channels=self.n_skip_channels1, out_channels=self.n_out_channels1, kernel_size=self.c_kernel_size)),
            ('e_act_f11', self.c_activation_f),
            ('e_conv12', nn.Conv2d(in_channels=self.n_out_channels1, out_channels=self.n_out_channels1, kernel_size=self.c_kernel_size)),
            ('e_act_f12', self.c_activation_f),
            ('e_conv_f', nn.Conv2d(in_channels=self.n_out_channels1, out_channels=self.final_out_channels,
                                   kernel_size=self.final_kernel_size))
        ]))
    # (dropout not implemented)
    # skip operation: crop the contracting-path feature maps and concatenate them with the expansive-path maps
    # returns concat: [n_batch, n_ch, h, w] + [n_batch, n_ch, h, w] -> [n_batch, n_ch + n_ch, h, w]
    def skipped_connection(self, cont_maps, exp_maps, height, width):
        cropped_f_maps = self.crop_feature_maps(cont_maps, height, width)
        return torch.cat((cropped_f_maps, exp_maps), 1)

    # features: [n_batch, n_channels, height, width]
    # h, w: image size after cropping
    def crop_feature_maps(self, features, h, w):
        h_old, w_old = features[0][0].size()
        x = math.ceil((h_old - h) / 2)
        y = math.ceil((w_old - w) / 2)
        return features[:, :, x:(x + h), y:(y + w)]
    def contracting_path(self, x):
        self.cont_layer1_out = self.cont_layer1(x)
        self.cont_layer2_in = self.c_max_pooling(self.cont_layer1_out)
        self.cont_layer2_out = self.cont_layer2(self.cont_layer2_in)
        self.cont_layer3_in = self.c_max_pooling(self.cont_layer2_out)
        self.cont_layer3_out = self.cont_layer3(self.cont_layer3_in)
        self.cont_layer4_in = self.c_max_pooling(self.cont_layer3_out)
        self.cont_layer4_out = self.cont_layer4(self.cont_layer4_in)
        self.cont_layer5_in = self.c_max_pooling(self.cont_layer4_out)
        self.cont_layer5_out = self.cont_layer5(self.cont_layer5_in)
        return self.cont_layer5_out

    # x = cont_layer5_out
    def expansive_path(self, x):
        x = self.exp_layer5(x)
        x = self.skipped_connection(self.cont_layer4_out, x, x.size()[2], x.size()[3])
        x = self.exp_layer4(x)
        x = self.skipped_connection(self.cont_layer3_out, x, x.size()[2], x.size()[3])
        x = self.exp_layer3(x)
        x = self.skipped_connection(self.cont_layer2_out, x, x.size()[2], x.size()[3])
        x = self.exp_layer2(x)
        x = self.skipped_connection(self.cont_layer1_out, x, x.size()[2], x.size()[3])
        x = self.exp_layer1(x)
        return x

    # input_x has to be of shape (n_batches, n_channels, height, width)
    def forward(self, input_x):
        o_x = self.contracting_path(input_x)
        o_x = self.expansive_path(o_x)
        return o_x
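For reference, the summary below came from a torchsummary call along these lines; the single-channel 255x255 input shape is inferred from the 253x253 output of the first 3x3 convolution, and the device handling is only illustrative:

from torchsummary import summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = U_net(in_channel=1, n_classes=1).to(device)
# input shape (channels, height, width) assumed from the reported output shapes
summary(model, (1, 255, 255))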
Summary of the model (torchsummary output):
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 253, 253] 640
BatchNorm2d-2 [-1, 64, 253, 253] 128
BatchNorm2d-3 [-1, 64, 253, 253] 128
ReLU-4 [-1, 64, 253, 253] 0
ReLU-5 [-1, 64, 253, 253] 0
ReLU-6 [-1, 64, 253, 253] 0
ReLU-7 [-1, 64, 253, 253] 0
ReLU-8 [-1, 64, 253, 253] 0
ReLU-9 [-1, 64, 253, 253] 0
ReLU-10 [-1, 64, 253, 253] 0
ReLU-11 [-1, 64, 253, 253] 0
ReLU-12 [-1, 64, 253, 253] 0
ReLU-13 [-1, 64, 253, 253] 0
Conv2d-14 [-1, 64, 251, 251] 36,928
BatchNorm2d-15 [-1, 64, 251, 251] 128
BatchNorm2d-16 [-1, 64, 251, 251] 128
ReLU-17 [-1, 64, 251, 251] 0
ReLU-18 [-1, 64, 251, 251] 0
ReLU-19 [-1, 64, 251, 251] 0
ReLU-20 [-1, 64, 251, 251] 0
ReLU-21 [-1, 64, 251, 251] 0
ReLU-22 [-1, 64, 251, 251] 0
ReLU-23 [-1, 64, 251, 251] 0
ReLU-24 [-1, 64, 251, 251] 0
ReLU-25 [-1, 64, 251, 251] 0
ReLU-26 [-1, 64, 251, 251] 0
MaxPool2d-27 [-1, 64, 125, 125] 0
Conv2d-28 [-1, 128, 123, 123] 73,856
BatchNorm2d-29 [-1, 128, 123, 123] 256
BatchNorm2d-30 [-1, 128, 123, 123] 256
ReLU-31 [-1, 128, 123, 123] 0
ReLU-32 [-1, 128, 123, 123] 0
ReLU-33 [-1, 128, 123, 123] 0
ReLU-34 [-1, 128, 123, 123] 0
ReLU-35 [-1, 128, 123, 123] 0
ReLU-36 [-1, 128, 123, 123] 0
ReLU-37 [-1, 128, 123, 123] 0
ReLU-38 [-1, 128, 123, 123] 0
ReLU-39 [-1, 128, 123, 123] 0
ReLU-40 [-1, 128, 123, 123] 0
Conv2d-41 [-1, 128, 121, 121] 147,584
BatchNorm2d-42 [-1, 128, 121, 121] 256
BatchNorm2d-43 [-1, 128, 121, 121] 256
ReLU-44 [-1, 128, 121, 121] 0
ReLU-45 [-1, 128, 121, 121] 0
ReLU-46 [-1, 128, 121, 121] 0
ReLU-47 [-1, 128, 121, 121] 0
ReLU-48 [-1, 128, 121, 121] 0
ReLU-49 [-1, 128, 121, 121] 0
ReLU-50 [-1, 128, 121, 121] 0
ReLU-51 [-1, 128, 121, 121] 0
ReLU-52 [-1, 128, 121, 121] 0
ReLU-53 [-1, 128, 121, 121] 0
MaxPool2d-54 [-1, 128, 60, 60] 0
Conv2d-55 [-1, 256, 58, 58] 295,168
BatchNorm2d-56 [-1, 256, 58, 58] 512
BatchNorm2d-57 [-1, 256, 58, 58] 512
ReLU-58 [-1, 256, 58, 58] 0
ReLU-59 [-1, 256, 58, 58] 0
ReLU-60 [-1, 256, 58, 58] 0
ReLU-61 [-1, 256, 58, 58] 0
ReLU-62 [-1, 256, 58, 58] 0
ReLU-63 [-1, 256, 58, 58] 0
ReLU-64 [-1, 256, 58, 58] 0
ReLU-65 [-1, 256, 58, 58] 0
ReLU-66 [-1, 256, 58, 58] 0
ReLU-67 [-1, 256, 58, 58] 0
BatchNorm2d-68 [-1, 256, 58, 58] 512
BatchNorm2d-69 [-1, 256, 58, 58] 512
Conv2d-70 [-1, 256, 56, 56] 590,080
ReLU-71 [-1, 256, 56, 56] 0
ReLU-72 [-1, 256, 56, 56] 0
ReLU-73 [-1, 256, 56, 56] 0
ReLU-74 [-1, 256, 56, 56] 0
ReLU-75 [-1, 256, 56, 56] 0
ReLU-76 [-1, 256, 56, 56] 0
ReLU-77 [-1, 256, 56, 56] 0
ReLU-78 [-1, 256, 56, 56] 0
ReLU-79 [-1, 256, 56, 56] 0
ReLU-80 [-1, 256, 56, 56] 0
MaxPool2d-81 [-1, 256, 28, 28] 0
Conv2d-82 [-1, 512, 26, 26] 1,180,160
BatchNorm2d-83 [-1, 512, 26, 26] 1,024
BatchNorm2d-84 [-1, 512, 26, 26] 1,024
ReLU-85 [-1, 512, 26, 26] 0
ReLU-86 [-1, 512, 26, 26] 0
ReLU-87 [-1, 512, 26, 26] 0
ReLU-88 [-1, 512, 26, 26] 0
ReLU-89 [-1, 512, 26, 26] 0
ReLU-90 [-1, 512, 26, 26] 0
ReLU-91 [-1, 512, 26, 26] 0
ReLU-92 [-1, 512, 26, 26] 0
ReLU-93 [-1, 512, 26, 26] 0
ReLU-94 [-1, 512, 26, 26] 0
BatchNorm2d-95 [-1, 512, 26, 26] 1,024
BatchNorm2d-96 [-1, 512, 26, 26] 1,024
Conv2d-97 [-1, 512, 24, 24] 2,359,808
ReLU-98 [-1, 512, 24, 24] 0
ReLU-99 [-1, 512, 24, 24] 0
ReLU-100 [-1, 512, 24, 24] 0
ReLU-101 [-1, 512, 24, 24] 0
ReLU-102 [-1, 512, 24, 24] 0
ReLU-103 [-1, 512, 24, 24] 0
ReLU-104 [-1, 512, 24, 24] 0
ReLU-105 [-1, 512, 24, 24] 0
ReLU-106 [-1, 512, 24, 24] 0
ReLU-107 [-1, 512, 24, 24] 0
MaxPool2d-108 [-1, 512, 12, 12] 0
Conv2d-109 [-1, 1024, 10, 10] 4,719,616
BatchNorm2d-110 [-1, 1024, 10, 10] 2,048
BatchNorm2d-111 [-1, 1024, 10, 10] 2,048
ReLU-112 [-1, 1024, 10, 10] 0
ReLU-113 [-1, 1024, 10, 10] 0
ReLU-114 [-1, 1024, 10, 10] 0
ReLU-115 [-1, 1024, 10, 10] 0
ReLU-116 [-1, 1024, 10, 10] 0
ReLU-117 [-1, 1024, 10, 10] 0
ReLU-118 [-1, 1024, 10, 10] 0
ReLU-119 [-1, 1024, 10, 10] 0
ReLU-120 [-1, 1024, 10, 10] 0
ReLU-121 [-1, 1024, 10, 10] 0
Conv2d-122 [-1, 1024, 8, 8] 9,438,208
BatchNorm2d-123 [-1, 1024, 8, 8] 2,048
BatchNorm2d-124 [-1, 1024, 8, 8] 2,048
ReLU-125 [-1, 1024, 8, 8] 0
ReLU-126 [-1, 1024, 8, 8] 0
ReLU-127 [-1, 1024, 8, 8] 0
ReLU-128 [-1, 1024, 8, 8] 0
ReLU-129 [-1, 1024, 8, 8] 0
ReLU-130 [-1, 1024, 8, 8] 0
ReLU-131 [-1, 1024, 8, 8] 0
ReLU-132 [-1, 1024, 8, 8] 0
ReLU-133 [-1, 1024, 8, 8] 0
ReLU-134 [-1, 1024, 8, 8] 0
ConvTranspose2d-135 [-1, 512, 16, 16] 2,097,664
Conv2d-136 [-1, 512, 14, 14] 4,719,104
ReLU-137 [-1, 512, 14, 14] 0
ReLU-138 [-1, 512, 14, 14] 0
ReLU-139 [-1, 512, 14, 14] 0
ReLU-140 [-1, 512, 14, 14] 0
ReLU-141 [-1, 512, 14, 14] 0
ReLU-142 [-1, 512, 14, 14] 0
ReLU-143 [-1, 512, 14, 14] 0
ReLU-144 [-1, 512, 14, 14] 0
ReLU-145 [-1, 512, 14, 14] 0
ReLU-146 [-1, 512, 14, 14] 0
Conv2d-147 [-1, 512, 12, 12] 2,359,808
ReLU-148 [-1, 512, 12, 12] 0
ReLU-149 [-1, 512, 12, 12] 0
ReLU-150 [-1, 512, 12, 12] 0
ReLU-151 [-1, 512, 12, 12] 0
ReLU-152 [-1, 512, 12, 12] 0
ReLU-153 [-1, 512, 12, 12] 0
ReLU-154 [-1, 512, 12, 12] 0
ReLU-155 [-1, 512, 12, 12] 0
ReLU-156 [-1, 512, 12, 12] 0
ReLU-157 [-1, 512, 12, 12] 0
ConvTranspose2d-158 [-1, 256, 24, 24] 524,544
Conv2d-159 [-1, 256, 22, 22] 1,179,904
ReLU-160 [-1, 256, 22, 22] 0
ReLU-161 [-1, 256, 22, 22] 0
ReLU-162 [-1, 256, 22, 22] 0
ReLU-163 [-1, 256, 22, 22] 0
ReLU-164 [-1, 256, 22, 22] 0
ReLU-165 [-1, 256, 22, 22] 0
ReLU-166 [-1, 256, 22, 22] 0
ReLU-167 [-1, 256, 22, 22] 0
ReLU-168 [-1, 256, 22, 22] 0
ReLU-169 [-1, 256, 22, 22] 0
Conv2d-170 [-1, 256, 20, 20] 590,080
ReLU-171 [-1, 256, 20, 20] 0
ReLU-172 [-1, 256, 20, 20] 0
ReLU-173 [-1, 256, 20, 20] 0
ReLU-174 [-1, 256, 20, 20] 0
ReLU-175 [-1, 256, 20, 20] 0
ReLU-176 [-1, 256, 20, 20] 0
ReLU-177 [-1, 256, 20, 20] 0
ReLU-178 [-1, 256, 20, 20] 0
ReLU-179 [-1, 256, 20, 20] 0
ReLU-180 [-1, 256, 20, 20] 0
ConvTranspose2d-181 [-1, 128, 40, 40] 131,200
Conv2d-182 [-1, 128, 38, 38] 295,040
ReLU-183 [-1, 128, 38, 38] 0
ReLU-184 [-1, 128, 38, 38] 0
ReLU-185 [-1, 128, 38, 38] 0
ReLU-186 [-1, 128, 38, 38] 0
ReLU-187 [-1, 128, 38, 38] 0
ReLU-188 [-1, 128, 38, 38] 0
ReLU-189 [-1, 128, 38, 38] 0
ReLU-190 [-1, 128, 38, 38] 0
ReLU-191 [-1, 128, 38, 38] 0
ReLU-192 [-1, 128, 38, 38] 0
Conv2d-193 [-1, 128, 36, 36] 147,584
ReLU-194 [-1, 128, 36, 36] 0
ReLU-195 [-1, 128, 36, 36] 0
ReLU-196 [-1, 128, 36, 36] 0
ReLU-197 [-1, 128, 36, 36] 0
ReLU-198 [-1, 128, 36, 36] 0
ReLU-199 [-1, 128, 36, 36] 0
ReLU-200 [-1, 128, 36, 36] 0
ReLU-201 [-1, 128, 36, 36] 0
ReLU-202 [-1, 128, 36, 36] 0
ReLU-203 [-1, 128, 36, 36] 0
ConvTranspose2d-204 [-1, 64, 72, 72] 32,832
Conv2d-205 [-1, 64, 70, 70] 73,792
ReLU-206 [-1, 64, 70, 70] 0
ReLU-207 [-1, 64, 70, 70] 0
ReLU-208 [-1, 64, 70, 70] 0
ReLU-209 [-1, 64, 70, 70] 0
ReLU-210 [-1, 64, 70, 70] 0
ReLU-211 [-1, 64, 70, 70] 0
ReLU-212 [-1, 64, 70, 70] 0
ReLU-213 [-1, 64, 70, 70] 0
ReLU-214 [-1, 64, 70, 70] 0
ReLU-215 [-1, 64, 70, 70] 0
Conv2d-216 [-1, 64, 68, 68] 36,928
ReLU-217 [-1, 64, 68, 68] 0
ReLU-218 [-1, 64, 68, 68] 0
ReLU-219 [-1, 64, 68, 68] 0
ReLU-220 [-1, 64, 68, 68] 0
ReLU-221 [-1, 64, 68, 68] 0
ReLU-222 [-1, 64, 68, 68] 0
ReLU-223 [-1, 64, 68, 68] 0
ReLU-224 [-1, 64, 68, 68] 0
ReLU-225 [-1, 64, 68, 68] 0
ReLU-226 [-1, 64, 68, 68] 0
Conv2d-227 [-1, 1, 68, 68] 65
================================================================
Total params: 31,046,465
Trainable params: 31,046,465
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 1564.78
Params size (MB): 118.43
Estimated Total Size (MB): 1683.46
----------------------------------------------------------------