I am trying to build a mirrored autoencoder for 512 x 512 greyscale images (binary masks), as described in section 3.1 of the following paper. However, when I run the model and pass its output into the loss function, the tensor sizes do not match (tensor a is of size 510 and tensor b is of size 512).
Before encoding my shape is: [2, 1, 512, 512]
After encoding my shape is: [2, 32, 14, 14]
After decoding my shape is: [2, 1, 510, 510]
I’m confused as to why the output is 510 x 510 instead of 512 x 512. I am also not sure whether my implementation follows the paper exactly.
Any help would be appreciated, thanks.
import torch.nn as nn


# Interpolate is a custom nn.Module (not part of torch.nn) that upsamples by scale_factor.
class AutoEncoderConv(nn.Module):
    def __init__(self):
        super(AutoEncoderConv, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            nn.MaxPool2d(2)
        )
        self.decoder = nn.Sequential(
            Interpolate(mode='bilinear', scale_factor=2),
            nn.ConvTranspose2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            Interpolate(mode='bilinear', scale_factor=2),
            nn.ConvTranspose2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            Interpolate(mode='bilinear', scale_factor=2),
            nn.ConvTranspose2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            Interpolate(mode='bilinear', scale_factor=2),
            nn.ConvTranspose2d(32, 32, kernel_size=3),
            nn.ReLU(True),
            Interpolate(mode='bilinear', scale_factor=2),
            nn.ConvTranspose2d(32, 1, kernel_size=3),
            nn.ReLU(True),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Print the tensor shape before encoding, after encoding, and after decoding.
        print()
        print("Start Encode: ", x.shape)
        x = self.encoder(x)
        print("Finished Encode: ", x.shape)
        x = self.decoder(x)
        print("Finished Decode: ", x.shape)
        return x
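Interpolate is not a torch.nn class, so the model above will not run without a definition for it. A minimal sketch of such a wrapper, assuming it does nothing more than forward to torch.nn.functional.interpolate so that the call can sit inside nn.Sequential (the constructor arguments are guessed from how it is used in the decoder):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Interpolate(nn.Module):
    """Assumed wrapper: lets F.interpolate be used as a layer in nn.Sequential."""

    def __init__(self, scale_factor, mode):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        # align_corners is only accepted for the interpolating modes;
        # False matches PyTorch's default behaviour for them.
        interpolating = ("linear", "bilinear", "bicubic", "trilinear")
        self.align_corners = False if mode in interpolating else None

    def forward(self, x):
        return F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )


# Doubles the spatial size and leaves batch and channel dimensions alone.
x = torch.zeros(2, 32, 14, 14)
print(Interpolate(mode='bilinear', scale_factor=2)(x).shape)  # torch.Size([2, 32, 28, 28])
```

With this definition each Interpolate layer only doubles the height and width, so any other size change in the decoder comes from the ConvTranspose2d layers.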
EDIT
The shapes after each encoder stage (E1 to E5) and each decoder stage (D1 to D5) are as follows:
START: torch.Size([2, 1, 512, 512])
E1: torch.Size([2, 32, 255, 255])
E2: torch.Size([2, 32, 126, 126])
E3: torch.Size([2, 32, 62, 62])
E4: torch.Size([2, 32, 30, 30])
E5: torch.Size([2, 32, 14, 14])
D1: torch.Size([2, 32, 30, 30])
D2: torch.Size([2, 32, 62, 62])
D3: torch.Size([2, 32, 126, 126])
D4: torch.Size([2, 32, 254, 254])
D5: torch.Size([2, 1, 510, 510])
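The numbers above can be reproduced without running the model. Below is a minimal sketch of the size arithmetic, using the output-size formulas from the PyTorch documentation for Conv2d, MaxPool2d and ConvTranspose2d (stride 1, dilation 1, no output_padding). The helper names are made up for illustration, and the padding=1 case is only an assumption about what might restore the size, not something taken from the paper.

```python
def conv_out(size, kernel=3, padding=0):
    # nn.Conv2d with stride 1: size + 2*padding - kernel + 1
    return size + 2 * padding - kernel + 1


def pool_out(size, kernel=2):
    # nn.MaxPool2d(2): stride defaults to the kernel size, result is floored
    return (size - kernel) // kernel + 1


def upsample_out(size, scale=2):
    # Interpolate(scale_factor=2) doubles the spatial size
    return size * scale


def conv_transpose_out(size, kernel=3, padding=0):
    # nn.ConvTranspose2d with stride 1: (size - 1) - 2*padding + kernel
    return (size - 1) - 2 * padding + kernel


def trace(size, padding, stages=5):
    """Return the spatial size after each encoder and each decoder stage."""
    encoder, decoder = [], []
    for _ in range(stages):
        size = pool_out(conv_out(size, padding=padding))
        encoder.append(size)
    for _ in range(stages):
        size = conv_transpose_out(upsample_out(size), padding=padding)
        decoder.append(size)
    return encoder, decoder


# kernel_size=3 with no padding, as in the model above
print(trace(512, padding=0))  # ([255, 126, 62, 30, 14], [30, 62, 126, 254, 510])

# hypothetical variant with padding=1 on every conv and transposed conv
print(trace(512, padding=1))  # ([256, 128, 64, 32, 16], [32, 64, 128, 256, 512])
```

So the arithmetic matches the printed shapes: each unpadded Conv2d removes 2 pixels before pooling, each ConvTranspose2d adds 2 pixels after upsampling, and the two do not cancel out over five stages. With padding=1 every convolution keeps the size, so only the pooling and upsampling change it and the output comes back to 512.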