I use a U-Net for image regression; the network is as follows:
class Unet_Resnet_101(nn.Module):
    """U-Net with a ResNet-101-style encoder for image regression.

    The encoder stacks bottleneck blocks in the 3/4/23/3 ResNet-101
    layout; the decoder upsamples x2 five times, concatenating the
    matching encoder feature map at each of the first four steps, and
    finishes with a 1x1 conv squashed to (0, 1) by a sigmoid.

    NOTE(review): each encoder stage reuses ONE BottleBlock instance for
    all of its repeated (non-transition) blocks — e.g. the conv4 stage
    applies `bottleneck_block_3` 22 times — so those repetitions SHARE
    WEIGHTS. A stock ResNet-101 uses distinct blocks per repetition;
    confirm the sharing is intentional.
    """

    def __init__(self):
        super(Unet_Resnet_101, self).__init__()
        # Stem: 5x5 stride-2 conv -> BN -> ReLU (1/2 resolution).
        # `input_channel` / `output_channel` are expected to be defined
        # at module level elsewhere in this file.
        self.conv_relu1 = nn.Sequential(
            nn.Conv2d(input_channel, 64, 5, stride=2, padding=1),  # fixed: was bare `Conv2d`
            nn.BatchNorm2d(64, eps=1e-3, momentum=0.99),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(3, 2, padding=1)  # -> 1/4 resolution

        # Encoder stages: one strided/projected transition block (`*t`)
        # plus one shared repeated block per stage (see class NOTE).
        self.bottleneck_block_1t = BottleBlock(64, 128, stride=1, with_conv_shortcut=True)
        self.bottleneck_block_1 = BottleBlock(128, 128)
        self.bottleneck_block_2t = BottleBlock(128, 256, stride=2, with_conv_shortcut=True)
        self.bottleneck_block_2 = BottleBlock(256, 256)
        self.bottleneck_block_3t = BottleBlock(256, 512, stride=2, with_conv_shortcut=True)
        self.bottleneck_block_3 = BottleBlock(512, 512)
        self.bottleneck_block_4t = BottleBlock(512, 1024, stride=2, with_conv_shortcut=True)
        self.bottleneck_block_4 = BottleBlock(1024, 1024)

        # Shared x2 upsampler. nn.Upsample is parameter-free, so making
        # it an attribute does not change the state_dict; hoisted here
        # to avoid re-instantiating a module on every forward() call.
        self.upsample = nn.Upsample(scale_factor=2)

        # Decoder convs. The `*_2` convs take (skip + upsampled) channels.
        self.conv_bn1_1 = ConvBN(1024, 512, 2)
        self.conv_bn1_2 = ConvBN(1024, 512, 3)
        self.conv_bn1_3 = ConvBN(512, 512, 3)
        self.conv_bn2_1 = ConvBN(512, 512, 2)
        self.conv_bn2_2 = ConvBN(256 + 512, 256, 3)
        self.conv_bn2_3 = ConvBN(256, 256, 3)
        self.conv_bn3_1 = ConvBN(256, 256, 2)
        self.conv_bn3_2 = ConvBN(256 + 128, 128, 3)
        self.conv_bn3_3 = ConvBN(128, 128, 3)
        self.conv_bn4_1 = ConvBN(128, 64, 2)
        self.conv_bn4_2 = ConvBN(128, 64, 3)
        self.conv_bn4_3 = ConvBN(64, 64, 3)
        self.conv_bn5_1 = ConvBN(64, 64, 2)
        self.conv_bn5_2 = ConvBN(64, 64, 3)
        # Head: 1x1 conv to `output_channel` maps, no activation here
        # (sigmoid is applied in forward()).
        self.conv_bn6 = ConvBN(64, output_channel, 1, use_activation=False)

    def forward(self, x):
        # ---- Encoder ----
        conv1_1 = self.conv_relu1(x)      # 1/2 res, 64 ch (skip for up9)
        pooled = self.maxpool(conv1_1)    # 1/4 res

        # conv2_x (1/4): transition + 2 applications of the shared block.
        conv2_out = self.bottleneck_block_1t(pooled)
        for _ in range(2):
            conv2_out = self.bottleneck_block_1(conv2_out)   # 128 ch

        # conv3_x (1/8): transition + 3 applications of the shared block.
        conv3_out = self.bottleneck_block_2t(conv2_out)
        for _ in range(3):
            conv3_out = self.bottleneck_block_2(conv3_out)   # 256 ch

        # conv4_x (1/16): transition + 22 applications of the shared
        # block (ResNet-101 layout; replaces the hand-unrolled chain).
        conv4_out = self.bottleneck_block_3t(conv3_out)
        for _ in range(22):
            conv4_out = self.bottleneck_block_3(conv4_out)   # 512 ch

        # conv5_x (1/32): transition + 2 applications of the shared block.
        conv5_out = self.bottleneck_block_4t(conv4_out)      # 512 -> 1024
        for _ in range(2):
            conv5_out = self.bottleneck_block_4(conv5_out)

        # ---- Decoder: upsample, fuse skip, refine ----
        up6 = self.conv_bn1_1(self.upsample(conv5_out))      # 1024 -> 512
        conv6 = self.conv_bn1_3(self.conv_bn1_2(torch.cat([conv4_out, up6], dim=1)))

        up7 = self.conv_bn2_1(self.upsample(conv6))          # 512 -> 512
        conv7 = self.conv_bn2_3(self.conv_bn2_2(torch.cat([conv3_out, up7], dim=1)))

        up8 = self.conv_bn3_1(self.upsample(conv7))          # 256 -> 256
        conv8 = self.conv_bn3_3(self.conv_bn3_2(torch.cat([conv2_out, up8], dim=1)))

        up9 = self.conv_bn4_1(self.upsample(conv8))          # 128 -> 64
        conv9 = self.conv_bn4_3(self.conv_bn4_2(torch.cat([conv1_1, up9], dim=1)))

        up10 = self.conv_bn5_1(self.upsample(conv9))         # back to full res
        # NOTE(review): conv_bn5_2 is applied twice as in the original
        # (there is no conv_bn5_3), so both passes share weights —
        # confirm that is intended.
        conv10 = self.conv_bn5_2(self.conv_bn5_2(up10))

        # Regression head squashed to (0, 1).
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.conv_bn6(conv10))
The training image size is 260×260 and training goes well. But when I try to predict a 19200×19200 image, it is too large and the GPU reports out of memory. Is there any way to modify the model to handle large-image prediction? Switching prediction to the CPU is also impossible — it reports a segmentation fault.