Can't recover input shape after going through my FCN

Hello !

I’m trying to implement a Fully Convolutional Network, based on a VGG16 encoder. I’ve set up the architecture for the decoder. When the input tensor has a shape of multiples of 224, the decoder’s output shape is fine. However, when it is not the case, for instance the image shape is (500,500) or (1080,720), I can’t recover the shape.

Here’s what I have so far:

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torchvision import models 


class FCN(nn.Module):

	def __init__(self, nb_classes): 

		super().__init__()

		vgg = models.vgg16()
		self.encoder = vgg.features

		self.relu = nn.ReLU(inplace = True)
		self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
		self.bn1     = nn.BatchNorm2d(512)
		self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
		self.bn2     = nn.BatchNorm2d(256)
		self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
		self.bn3     = nn.BatchNorm2d(128)
		self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
		self.bn4     = nn.BatchNorm2d(64)
		self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
		self.bn5     = nn.BatchNorm2d(32)
		self.classifier = nn.Conv2d(32, nb_classes, kernel_size=1) 


	def forward(self, x): 

		outputs = {}
		for i,l in enumerate(self.encoder): 
			x = l(x)

			if isinstance(l, nn.Conv2d): 
				print(x.shape)
			
		print('Starting decoder')

		score = self.relu(self.deconv1(x))

		print(score.shape)

		score = self.bn1(score)
		score = self.relu(self.deconv2(score))
		
		print(score.shape)

		score = self.bn2(score)
		score = self.relu(self.deconv3(score))

		print(score.shape)

		score = self.bn3(score)
		score = self.relu(self.deconv4(score))

		print(score.shape)
		score = self.bn4(score)
		score = self.bn5(self.relu(self.deconv5(score)))

		print(score.shape)
		return self.classifier(score)


fcn = FCN(2)

img = torch.rand(1,3,500,500)
out = fcn(img)

How should I change the deconv layers to make this work ?

Thanks !