This post saved my day, thanks! For simplicity, it is OK to ignore whether the image width/height is odd or even; here is the simplified code.
import torchvision.transforms.functional as F
class SquarePad:
    """Pad an image with a constant zero border so that it becomes square.

    The shorter side is padded on both edges up to the length of the longer
    side. When the size difference is odd, the extra pixel goes to the
    right/bottom edge so the result is exactly square rather than one pixel
    short.
    """

    @staticmethod
    def _padding_for(width: int, height: int) -> tuple:
        """Return the ``(left, top, right, bottom)`` padding for a square result.

        Args:
            width: Current image width in pixels.
            height: Current image height in pixels.

        Returns:
            A 4-tuple of non-negative ints such that
            ``left + width + right == top + height + bottom``.
        """
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        # The far edges take the remainder; halving once and reusing the
        # value for both edges would lose a pixel on odd differences.
        right = side - width - left
        bottom = side - height - top
        return (left, top, right, bottom)

    def __call__(self, image):
        """Return ``image`` padded to a square.

        Args:
            image: Object whose ``size`` attribute unpacks to
                ``(width, height)`` (presumably a PIL image -- confirm
                against callers).

        Returns:
            The result of ``F.pad`` with fill value 0 and ``'constant'``
            padding mode.
        """
        width, height = image.size
        return F.pad(image, self._padding_for(width, height), 0, 'constant')
# Now use it in place of the transforms.Pad class.
# Preprocessing pipeline: square-pad, resize, center-crop, convert to a
# tensor, then normalize three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        SquarePad(),
        transforms.Resize(size=image_size),
        transforms.CenterCrop(size=image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)