I’m trying to implement a model which is similar to the Unet based the attached architecture. (Supplementary materials for:
DeepLearningforSegmentationusinganOpenLarge-ScaleDatasetin2DEchocardiography)
I used this implementation but I did change it based on the attached pic.
here is the Unet 2 implementation:
class BaseConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding, stride, droup_rate = False):
super(BaseConv, self).__init__()
self.act = nn.ReLU()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding, stride)
self.b1 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding, stride)
self.b2 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True)
def forward(self, x):
x = self.conv1(x)
x = self.act(self.b1(x))
x = self.conv2(x)
x = self.act(self.b2(x))
return x
class DownConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
super(DownConv, self).__init__()
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv_block = BaseConv(in_channels, out_channels, kernel_size, padding, stride)
def forward(self, x):
x = self.pool1(x)
x = self.conv_block(x)
return x
class UpConv(nn.Module):
def __init__ (self, in_channels, in_channels_skip, out_channels, kernel_size, padding, stride):
super(UpConv, self).__init__()
self.act = nn.ReLU()
self.conv_trans1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, padding=0, stride=2)
self.b3 = nn.BatchNorm2d(in_channels, eps=1e-05, momentum=0.1, affine=True)
self.conv3 = nn.Conv2d(in_channels=in_channels + in_channels_skip, out_channels= out_channels, kernel_size=kernel_size, padding=padding, stride=stride)
self.b4 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True)
self.conv4 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride)
self.b5 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True)
def forward(self, x, x_skip):
x = self.conv_trans1(x)
x = self.act(self.b3(x))
x = self.conv3(x)
x = torch.cat((x, x_skip), dim=1)
x = self.act(self.b4(x))
x = self.conv4(x)
x = self.act(self.b5(x))
return x
class UNet(nn.Module):
def __init__(self, in_channels, out_channels, n_class, kernel_size, padding, stride, droup_rate = False):
super(UNet, self).__init__()
self.down1 = DownConv(in_channels, out_channels, kernel_size, padding, stride)#48
self.down2 = DownConv(out_channels, 2 * out_channels, kernel_size, padding, stride)#96
self.down3 = DownConv(2 * out_channels, 4 * out_channels, kernel_size, padding, stride)#192
self.down4 = DownConv(4 * out_channels, 8 * out_channels, kernel_size, padding, stride)#384
self.down5 = BaseConv(8 * out_channels, 16 * out_channels, kernel_size, padding, stride)#768
self.up4 = UpConv(16 * out_channels, 8 * out_channels, 8 * out_channels, kernel_size, padding, stride)
self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels, kernel_size, padding, stride)
self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels, kernel_size, padding, stride)
self.up1 = UpConv(2 * out_channels, out_channels, out_channels, kernel_size, padding, stride)
self.out = nn.Conv2d(out_channels, n_class, kernel_size, padding, stride)
def forward(self, x):
# Encoder
x1 = self.down1(x)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x5 = self.down5(x4)
# Decoder
x_up1 = self.up4(x5, x4)
x_up2 = self.up3(x_up1, x3)
x_up3 = self.up2(x_up2, x2)
x_up4 = self.up1(x_up3, x1)
x_out = F.log_softmax(self.out(x_up4), 1)
print(x_out.size())
return x_out
model = UNet(in_channels=1,
out_channels=48,
n_class=2,
kernel_size=3,
padding=1,
stride=1)
model = model.to(device)
#print(model)
print("UNet model created")
#Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
Currently, x_up1
causing an error RuntimeError: Given groups=1, weight of size [384, 1152, 3, 3], expected input[1, 768, 30, 40] to have 1152 channels, but got 768 channels instead
I’m not sure where I am doing wrong. any comments would be appreciated.