Hi,
I am trying to train a U-Net model using a modified implementation of this code example from PyTorch Lightning.
I’m running into an issue (I believe it may be in the contraction block) where my kernel size is larger than the input image. By the end of the contraction, I end up with an image of size `torch.Size([8, 64, 2, 2])`, which is quite obviously too small to be convolved.
Here is the code for the `UNet` class:
class UNet(pl.LightningModule):
    """U-Net style encoder/decoder packaged as a Lightning module.

    A 1x1 convolution lifts the input to ``hidden_channels``. It is followed
    by ``depth`` contracting blocks (block ``i`` is built with
    ``hidden_channels * 2**i`` input channels), then ``depth`` expanding
    blocks applied deepest-first, each receiving one stored contracting-path
    tensor as its skip connection. A final 1x1 convolution maps to
    ``output_channels``.

    Args:
        in_channels: Channel count expected by the first 1x1 convolution.
        output_channels: Channel count produced by the final 1x1 convolution.
        hidden_channels: Channel count after the first 1x1 convolution.
        depth: Number of contracting blocks (and of expanding blocks).
    """

    def __init__(self, in_channels, output_channels, hidden_channels=64,
                 depth=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv_final = nn.Conv2d(
            hidden_channels, output_channels, kernel_size=1)
        self.depth = depth

        # Encoder blocks are constructed before decoder blocks so that
        # parameter initialisation happens in the same order as before.
        self.contracting_layers = nn.ModuleList(
            [ContractingBlock(hidden_channels * 2**level)
             for level in range(depth)]
        )
        self.expanding_layers = nn.ModuleList(
            [ExpandingBlock(hidden_channels * 2**level)
             for level in range(1, depth + 1)]
        )

    def forward(self, x):
        """Run the contracting path, then the expanding path with skips.

        Args:
            x: Input tensor fed to the first 1x1 convolution.

        Returns:
            The output of the final 1x1 convolution.
        """
        x = self.conv1(x)
        # skips[0] is the stem output; skips[k] is the output of contracting
        # block k-1. The deepest entry is stored but never used as a skip.
        skips = [x]
        for block in self.contracting_layers:
            x = block(x)
            skips.append(x)

        # Walk back up: expanding block `level` is paired with skips[level].
        for level in reversed(range(self.depth)):
            x = self.expanding_layers[level](x, skips[level])

        return self.conv_final(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss for one batch.

        ``batch`` is indexed with the keys ``'image'`` and ``'mask'``.
        ``criterion`` is not defined in this class; it is looked up in the
        enclosing (module) scope when this method runs.
        """
        images = batch['image']
        targets = batch['mask']
        loss = criterion(self.forward(images), targets)
        self.log('loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module.

        ``lr`` is not defined in this class; it is looked up in the enclosing
        (module) scope when this method runs.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        return optimizer
It relies on the following `DoubleConv`, `ContractingBlock`, `UpsampleConv`, and `ExpandingBlock` classes:
Other blocks
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU and a PrintLayer.

    Args:
        in_channels: Channel count of the input to the first convolution.
        out_channels: Channel count produced by both convolutions.
        padding: Padding forwarded to both convolutions. The default of 0
            keeps the original unpadded behaviour, in which each 3x3
            convolution removes 2 pixels from the height and the width
            (4 in total per ``DoubleConv``). Pass ``padding=1`` to keep the
            spatial size unchanged through the block.
    """

    def __init__(self, in_channels, out_channels, padding=0):
        super().__init__()
        # PrintLayer is defined outside this block; it is kept in the same
        # positions as before (after each ReLU).
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (3, 3), padding=padding),
            nn.ReLU(inplace=True),
            PrintLayer(),
            nn.Conv2d(out_channels, out_channels, (3, 3), padding=padding),
            nn.ReLU(inplace=True),
            PrintLayer(),
        )

    def forward(self, x):
        """Apply the conv/ReLU/print stack to ``x`` and return the result."""
        return self.net(x)
class ContractingBlock(nn.Module):
    """Downsampling stage: a channel-doubling DoubleConv, then max pooling.

    The pooling layer uses a 3x3 window with stride 2 and padding 1.

    Args:
        in_channels: Channel count of the input; the DoubleConv is built to
            output twice as many channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        doubled = in_channels * 2
        self.double_conv = DoubleConv(in_channels, doubled)
        self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Return ``x`` passed through the double convolution and the pooling."""
        return self.pooling(self.double_conv(x))
class UpsampleConv(nn.Module):
    """Bilinear 2x upsampling followed by a single convolution.

    A PrintLayer sits after the upsampling step and after the convolution.

    Args:
        in_channels: Channel count of the input (unchanged by upsampling).
        out_channels: Channel count produced by the convolution.
        kernel_size: Kernel size of the convolution.
        padding: Padding forwarded to the convolution. The default of 0 keeps
            the original unpadded behaviour, in which the convolution removes
            ``kernel_size - 1`` pixels from the height and the width. For an
            odd ``kernel_size``, ``padding=kernel_size // 2`` keeps the
            upsampled spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=0):
        super().__init__()
        # PrintLayer is defined outside this block; positions are unchanged.
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            PrintLayer(),
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      padding=padding),
            PrintLayer(),
        )

    def forward(self, x):
        """Upsample ``x``, convolve it, and return the result."""
        return self.net(x)
class ExpandingBlock(nn.Module):
    """Upsampling stage that merges a skip connection from the contracting path.

    The input is passed through an UpsampleConv that halves the channel
    count, concatenated along dim 1 with the cropped skip tensor, and then
    passed through a DoubleConv built for ``in_channels`` inputs and
    ``in_channels // 2`` outputs.

    Args:
        in_channels: Channel count of the tensor entering the block.
    """

    def __init__(self, in_channels):
        super().__init__()
        halved = in_channels // 2
        self.upsample = UpsampleConv(in_channels, halved)
        self.double_conv = DoubleConv(in_channels, halved)

    def forward(self, x, skip_conn):
        """Upsample ``x``, concatenate the cropped skip, and convolve.

        Args:
            x: Tensor from the previous (deeper) stage.
            skip_conn: Tensor saved on the contracting path.

        Returns:
            The output of the block's DoubleConv.
        """
        upsampled = self.upsample(x)
        # `crop` is defined outside this block; presumably it trims the skip
        # tensor to the upsampled tensor's shape (TODO confirm).
        trimmed_skip = crop(skip_conn, upsampled.shape)
        merged = torch.cat([trimmed_skip, upsampled], dim=1)
        return self.double_conv(merged)
Thanks to @PA_Nik, I threw in a custom `PrintLayer`, which allows me to examine the tensor dimensions at each layer. They are as follows:
torch.Size([8, 128, 62, 62])
torch.Size([8, 128, 60, 60])
torch.Size([8, 256, 28, 28])
torch.Size([8, 256, 26, 26])
torch.Size([8, 512, 11, 11])
torch.Size([8, 512, 9, 9])
torch.Size([8, 512, 10, 10])
torch.Size([8, 256, 8, 8])
torch.Size([8, 256, 6, 6])
torch.Size([8, 256, 4, 4])
torch.Size([8, 256, 8, 8])
torch.Size([8, 128, 6, 6])
torch.Size([8, 128, 4, 4])
torch.Size([8, 128, 2, 2])
torch.Size([8, 128, 4, 4])
torch.Size([8, 64, 2, 2])
Could somebody please guide me through how to add the requisite padding (or otherwise adjust the contraction steps) so that the spatial dimensions of my feature maps do not shrink too far?
Is there an easy way to check which part of the network is contracting too far?