class UNetDown(nn.Module):
    """Encoder stage: strided conv, optional instance norm, ReLU, optional dropout.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        normalize: When true, ``nn.InstanceNorm2d(out_size)`` follows the conv.
        dropout: Dropout probability. A falsy value (the default ``0.0``)
            adds no dropout layer at all.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        # Must be spelled __init__ so nn.Module registers the sub-layers.
        super(UNetDown, self).__init__()
        # Kernel 4, stride 2, padding 1: an even spatial size is halved.
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.ReLU(inplace=True))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stage and return the downsampled features."""
        return self.model(x)
class UNetUp(nn.Module):
    """Decoder stage: transposed conv, instance norm, ReLU, then skip concat.

    Args:
        in_size: Number of input channels.
        out_size: Number of channels produced by the transposed conv. The
            tensor returned by ``forward`` additionally carries the channels
            of ``skip_input``.
        dropout: Dropout probability. A falsy value (the default ``0.0``)
            adds no dropout layer at all.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        # Must be spelled __init__ so nn.Module registers the sub-layers.
        super(UNetUp, self).__init__()
        layers = [
            # Kernel 4, stride 2, padding 1: doubles the spatial size.
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate ``skip_input`` along dim 1.

        ``skip_input`` must match the upsampled tensor in every dimension
        except dim 1, otherwise ``torch.cat`` raises.
        """
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)
        return x
class GeneratorUNet(nn.Module):
    """U-Net generator fed by a four-branch multi-scale input stem.

    The stem subsamples the input with strides 8, 4 and 2 (plus the
    full-resolution input), lifts every branch to 64 channels with a 3x3
    conv followed by an ``FCANet(64, 64)``, resizes the three subsampled
    branches to 256x256 and sums all four. The sum goes through eight
    ``UNetDown`` stages and seven ``UNetUp`` stages with skip connections,
    and a final upsample + conv + tanh head.

    Args:
        in_channels: Input channel count of the first encoder stage
            (``down1``) only; the stem convolutions are hard-coded to 3
            input channels.
        out_channels: Number of channels in the generated output.

    NOTE(review): the stem output has 64 channels, yet ``down1`` is built
    with ``in_channels`` (default 3). As written, ``forward`` only lines up
    when the model is constructed with ``in_channels=64`` -- confirm against
    the call sites before changing either side.

    NOTE(review): the branches are resized to a fixed 256 and then added to
    the full-resolution branch, so inputs are presumably expected to be
    256x256 -- confirm.
    """

    def __init__(self, in_channels=3, out_channels=1):
        # Must be spelled __init__ so nn.Module registers the sub-layers.
        super(GeneratorUNet, self).__init__()

        # Kernel-size-1 average pooling with a stride is plain subsampling:
        # it keeps every 8th / 4th / 2nd pixel.
        self.size1 = nn.AvgPool2d(1, stride=8)
        self.size3 = nn.AvgPool2d(1, stride=4)
        self.size4 = nn.AvgPool2d(1, stride=2)

        # One (3x3 conv -> FCANet) pair per scale, each producing 64 channels.
        self.convinput = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.factor_in = FCANet(64, 64)
        self.conva = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.factor_a = FCANet(64, 64)
        self.convb = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.factor_b = FCANet(64, 64)
        self.convc = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.factor_c = FCANet(64, 64)

        # Encoder. The c2..c7 modules are registered (and therefore appear
        # in the state dict) but forward() never calls them.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.c2 = FCANet(128, 128)
        self.down3 = UNetDown(128, 256)
        self.c3 = FCANet(256, 256)
        self.down4 = UNetDown(256, 512, dropout=0.0)
        self.c4 = FCANet(512, 512)
        self.down5 = UNetDown(512, 512, dropout=0.0)
        self.c5 = FCANet(512, 512)
        self.down6 = UNetDown(512, 512, dropout=0.0)
        self.c6 = FCANet(512, 512)
        self.down7 = UNetDown(512, 512, dropout=0.0)
        self.c7 = FCANet(512, 512)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.0)

        # Decoder. From up2 onwards the input width is the previous stage's
        # output plus the concatenated skip connection.
        self.up1 = UNetUp(512, 512, dropout=0.0)
        self.up2 = UNetUp(1024, 512, dropout=0.0)
        self.up3 = UNetUp(1024, 512, dropout=0.0)
        self.up4 = UNetUp(1024, 512, dropout=0.0)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        # Head: x2 upsample, pad one pixel on the left and top, then a 4x4
        # conv with padding 1, which keeps the upsampled size; tanh output.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Dropout(0.0),
            nn.Tanh(),
        )

    def forward(self, input):
        """Generate the output map for ``input``.

        Args:
            input: Batched image tensor; the stem convolutions require
                3 channels.

        Returns:
            The tanh-activated output of the final head.
        """
        # Subsample before `input` is rebound to its feature map below.
        a = self.size1(input)
        b = self.size3(input)
        c = self.size4(input)

        input = self.convinput(input)
        input = self.factor_in(input)

        a = self.conva(a)
        a = self.factor_a(a)
        a = nn.functional.interpolate(a, size=256, mode='bilinear', align_corners=True)

        b = self.convb(b)
        b = self.factor_b(b)
        b = nn.functional.interpolate(b, size=256, mode='bilinear', align_corners=True)

        c = self.convc(c)
        c = self.factor_c(c)
        c = nn.functional.interpolate(c, size=256, mode='bilinear', align_corners=True)

        # Fuse the four scales by element-wise sum.
        x = (input + a + b + c)

        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)
class FCANet(nn.Module):
    """Two-branch feature block whose outputs are blended 0.3 / 0.7.

    Both branches first reduce the input to ``in_channels // 4`` channels
    with a 3x3 conv + batch norm + ReLU. One branch then runs through
    ``non_bottleneck_1d``, the other through ``CAM_Module``; each is
    followed by another 3x3 conv + batch norm + ReLU. The weighted sum is
    projected to ``out_channels`` by dropout and a 1x1 conv.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    NOTE(review): ``non_bottleneck_1d`` and ``CAM_Module`` are defined
    outside this block; presumably a factorised residual block and a
    channel attention module respectively -- confirm.
    """

    def __init__(self, in_channels, out_channels):
        # Must be spelled __init__ so nn.Module registers the sub-layers.
        super(FCANet, self).__init__()
        inter_channels = in_channels // 4

        # Channel-reducing entry conv of each branch.
        self.conv5a = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )
        self.conv5c = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )

        self.sa = non_bottleneck_1d(inter_channels, 0.3, 2)
        self.sc = CAM_Module(inter_channels)

        # Refinement conv after each branch's core module.
        self.conv51 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )
        self.conv52 = nn.Sequential(
            nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
        )

        # Dropout2d(p=0.1, inplace=False) followed by a 1x1 projection.
        self.conv8 = nn.Sequential(
            nn.Dropout2d(0.1, False),
            nn.Conv2d(inter_channels, out_channels, 1),
        )

    def forward(self, x):
        """Return the blended two-branch features projected to ``out_channels``."""
        # Branch through self.sa.
        feat1 = self.conv5a(x)
        sa_feat = self.sa(feat1)
        sa_conv = self.conv51(sa_feat)

        # Branch through self.sc.
        feat2 = self.conv5c(x)
        sc_feat = self.sc(feat2)
        sc_conv = self.conv52(sc_feat)

        # Fixed blend: 0.3 of the sa branch, 0.7 of the sc branch.
        feat_sum = 0.3 * sa_conv + 0.7 * sc_conv
        sasc_output = self.conv8(feat_sum)
        return sasc_output