I have a UNet++ and I'm trying to replace the up-sampling operations with transposed convolutions. Is there a more straightforward way than defining every new transposed-convolution layer individually inside the UNetPP class?
class blockUNetPP(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    With a 3x3 kernel and padding of 1 the spatial size of the input is
    preserved; only the channel count changes
    (``in_channels -> middle_channels -> out_channels``).

    Args:
        in_channels: Channel count of the input tensor.
        middle_channels: Channel count produced by the first convolution.
        out_channels: Channel count produced by the second convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # A single ReLU module is reused after both batch-norm layers; it
        # holds no parameters, so sharing it is safe.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply both conv/batch-norm/ReLU stages to ``x`` and return the result."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        return out
class UNetPP(nn.Module):
    """UNet++ style network with four down-sampling levels and dense skips.

    Nodes are named ``x{i}_{j}``: ``i`` is the depth (0 = full resolution,
    4 = deepest) and ``j`` is the position along the skip pathway at that
    depth.  Node ``x{i}_{j}`` (``j > 0``) receives all earlier nodes of the
    same depth concatenated with the up-sampled node ``x{i+1}_{j-1}``.

    Down-sampling uses stride-2 convolutions (kernel 4, padding 1).
    Up-sampling is selectable through ``up_mode``:

    * ``"upsample"`` (default): one shared, parameter-free
      ``nn.Upsample(scale_factor=2)``, exactly as before.
    * ``"transpose"``: one learnable ``nn.ConvTranspose2d`` (kernel 4,
      stride 2, padding 1, channel count preserved) per up-sampling site.
      The layers are created in a loop and stored in ``self.up_convs``, so
      no per-layer attribute has to be written by hand.

    The input is expected to be a 4-D tensor ``(N, input_channels, H, W)``.
    ``H`` and ``W`` should be multiples of 16 so that every stride-2
    convolution halves and every up-sampling doubles the size exactly;
    otherwise the channel-wise concatenations fail with a shape mismatch.

    Args:
        channelExponent: Base width is ``2 ** channelExponent`` channels;
            deeper levels use 2x, 4x, 8x and 16x that width.
        dropout: Accepted for interface compatibility; no layer in this
            class currently uses it.
        input_channels: Channel count of the input tensor.
        output_channels: Channel count produced by the final 1x1
            convolution (defaults to 3, the previously hard-coded value).
        up_mode: ``"upsample"`` or ``"transpose"`` (see above).

    Raises:
        ValueError: If ``up_mode`` is not one of the supported values.
    """

    def __init__(self, channelExponent=6, dropout=0., input_channels=3,
                 output_channels=3, up_mode="upsample"):
        super().__init__()
        if up_mode not in ("upsample", "transpose"):
            raise ValueError(
                "up_mode must be 'upsample' or 'transpose', got %r" % (up_mode,)
            )
        self.up_mode = up_mode

        channels = int(2 ** channelExponent + 0.5)
        # Channel width at each depth 0..4.
        c0, c1, c2, c3, c4 = (channels * 2 ** level for level in range(5))

        self.up = nn.Upsample(scale_factor=2)

        # Backbone (column 0): block at each depth, preceded by a stride-2
        # convolution that halves the spatial size.
        self.conv0_0 = blockUNetPP(input_channels, c0, c0)
        self.bbconv1_0 = nn.Conv2d(c0, c0, 4, 2, 1)
        self.conv1_0 = blockUNetPP(c0, c1, c1)
        self.bbconv2_0 = nn.Conv2d(c1, c1, 4, 2, 1)
        self.conv2_0 = blockUNetPP(c1, c2, c2)
        self.bbconv3_0 = nn.Conv2d(c2, c2, 4, 2, 1)
        self.conv3_0 = blockUNetPP(c2, c3, c3)
        self.bbconv4_0 = nn.Conv2d(c3, c3, 4, 2, 1)
        self.conv4_0 = blockUNetPP(c3, c4, c4)

        # Nested nodes: input width = j same-depth skips + the up-sampled
        # node from one level deeper (whose width is twice this depth's).
        self.conv0_1 = blockUNetPP(c0 + c1, c0, c0)
        self.conv1_1 = blockUNetPP(c1 + c2, c1, c1)
        self.conv2_1 = blockUNetPP(c2 + c3, c2, c2)
        self.conv3_1 = blockUNetPP(c3 + c4, c3, c3)
        self.conv0_2 = blockUNetPP(2 * c0 + c1, c0, c0)
        self.conv1_2 = blockUNetPP(2 * c1 + c2, c1, c1)
        self.conv2_2 = blockUNetPP(2 * c2 + c3, c2, c2)
        self.conv0_3 = blockUNetPP(3 * c0 + c1, c0, c0)
        self.conv1_3 = blockUNetPP(3 * c1 + c2, c1, c1)
        self.conv0_4 = blockUNetPP(4 * c0 + c1, c0, c0)

        self.final = nn.Conv2d(c0, output_channels, kernel_size=1)

        # Optional learnable up-sampling layers, keyed by the node they are
        # applied to ("1_0", "2_1", ...).  Created last so that the default
        # mode builds exactly the same parameters, in the same order, as
        # before; the dict stays empty (no state_dict entries) in that mode.
        self.up_convs = nn.ModuleDict()
        if up_mode == "transpose":
            for level in range(1, 5):
                width = channels * 2 ** level
                # Depth `level` has nodes x{level}_0 .. x{level}_{4-level}.
                for col in range(5 - level):
                    self.up_convs["%d_%d" % (level, col)] = nn.ConvTranspose2d(
                        width, width, kernel_size=4, stride=2, padding=1
                    )

    def _upsample(self, key, x):
        """Double the spatial size of node ``x{key}`` using the configured mode."""
        if self.up_mode == "transpose":
            return self.up_convs[key](x)
        return self.up(x)

    def forward(self, input):
        """Run the nested U-Net and return the ``final`` 1x1 conv of ``x0_4``."""
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.bbconv1_0(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self._upsample("1_0", x1_0)], 1))

        x2_0 = self.conv2_0(self.bbconv2_0(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self._upsample("2_0", x2_0)], 1))
        x0_2 = self.conv0_2(
            torch.cat([x0_0, x0_1, self._upsample("1_1", x1_1)], 1)
        )

        x3_0 = self.conv3_0(self.bbconv3_0(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self._upsample("3_0", x3_0)], 1))
        x1_2 = self.conv1_2(
            torch.cat([x1_0, x1_1, self._upsample("2_1", x2_1)], 1)
        )
        x0_3 = self.conv0_3(
            torch.cat([x0_0, x0_1, x0_2, self._upsample("1_2", x1_2)], 1)
        )

        x4_0 = self.conv4_0(self.bbconv4_0(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self._upsample("4_0", x4_0)], 1))
        x2_2 = self.conv2_2(
            torch.cat([x2_0, x2_1, self._upsample("3_1", x3_1)], 1)
        )
        x1_3 = self.conv1_3(
            torch.cat([x1_0, x1_1, x1_2, self._upsample("2_2", x2_2)], 1)
        )
        x0_4 = self.conv0_4(
            torch.cat([x0_0, x0_1, x0_2, x0_3, self._upsample("1_3", x1_3)], 1)
        )

        output = self.final(x0_4)
        return output