Hello !
I’m trying to implement a Fully Convolutional Network, based on a VGG16 encoder. I’ve set up the architecture for the decoder. When the input tensor has a shape of multiples of 224, the decoder’s output shape is fine. However, when it is not the case, for instance the image shape is (500,500) or (1080,720), I can’t recover the shape.
Here’s what I have so far:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class FCN(nn.Module):
def __init__(self, nb_classes):
super().__init__()
vgg = models.vgg16()
self.encoder = vgg.features
self.relu = nn.ReLU(inplace = True)
self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn1 = nn.BatchNorm2d(512)
self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn2 = nn.BatchNorm2d(256)
self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn4 = nn.BatchNorm2d(64)
self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn5 = nn.BatchNorm2d(32)
self.classifier = nn.Conv2d(32, nb_classes, kernel_size=1)
def forward(self, x):
outputs = {}
for i,l in enumerate(self.encoder):
x = l(x)
if isinstance(l, nn.Conv2d):
print(x.shape)
print('Starting decoder')
score = self.relu(self.deconv1(x))
print(score.shape)
score = self.bn1(score)
score = self.relu(self.deconv2(score))
print(score.shape)
score = self.bn2(score)
score = self.relu(self.deconv3(score))
print(score.shape)
score = self.bn3(score)
score = self.relu(self.deconv4(score))
print(score.shape)
score = self.bn4(score)
score = self.bn5(self.relu(self.deconv5(score)))
print(score.shape)
return self.classifier(score)
fcn = FCN(2)
img = torch.rand(1,3,500,500)
out = fcn(img)
How should I change the deconv layers to make this work ?
Thanks !