You can use `transforms.Lambda` to wrap a call to the functional API (`torch.nn.functional`) so that it can be used as a step inside `transforms.Compose`:
import torch
import torch.nn.functional as F
from torchvision import transforms


def avg_pool_3x3(img_tensor):
    """Apply a 3x3, stride-1 average pool (no padding) to an image tensor.

    Defined as a named module-level function instead of an inline lambda so the
    resulting transform can be pickled (e.g. by DataLoader worker processes).
    """
    return F.avg_pool2d(img_tensor, kernel_size=3, stride=1)


transform = transforms.Compose([
    transforms.CenterCrop((80, 80)),
    # PIL image -> float tensor (C, H, W) with values in [0, 1].
    transforms.ToTensor(),
    # Lambda lets an arbitrary callable act as a Compose step.
    transforms.Lambda(avg_pool_3x3),
])

# ToPILImage scales float tensors by 255 before converting to bytes, so the
# input must lie in [0, 1]; torch.rand (unlike torch.randn) guarantees that.
img = transforms.ToPILImage()(torch.rand(3, 224, 224))
out = transform(img)  # 80x80 crop, 3x3 kernel, stride 1 -> shape (3, 78, 78)