What's the fastest way to prepare a batch from different-sized images by padding?

This is my collate function for doing what's in the title: it pads every image in the batch up to the largest height and width, then stacks them. Note that the input images have different sizes, which is the whole point of the padding.

import torch
import torch.nn.functional as F

def collate_fn(batch):
    # pad every image to the largest height/width in the batch, then stack
    max_h = max(img.shape[-2] for img in batch)
    max_w = max(img.shape[-1] for img in batch)
    for i, img in enumerate(batch):
        pad_left, pad_right, pad_top, pad_bottom = 0, 0, 0, 0
        if (diff_w := max_w - img.shape[-1]) > 0:
            # split the width padding as evenly as possible
            pad_left = diff_w // 2
            pad_right = diff_w - pad_left
        if (diff_h := max_h - img.shape[-2]) > 0:
            pad_top = diff_h // 2
            pad_bottom = diff_h - pad_top
        if any([pad_left, pad_right, pad_top, pad_bottom]):
            # F.pad takes (left, right, top, bottom) for the last two dims
            batch[i] = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom))
    return torch.stack(batch, dim=0)
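
For context, this is roughly how I'm plugging it into a DataLoader (minimal sketch; MyImageDataset is just a placeholder for any dataset that yields CHW tensors of varying spatial size):

from torch.utils.data import DataLoader

# MyImageDataset is hypothetical: any Dataset returning CHW tensors works here
loader = DataLoader(MyImageDataset(), batch_size=16, collate_fn=collate_fn)
images = next(iter(loader))  # shape: (16, C, max_h, max_w) for that batch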

Wondering if there are any neat tricks or obvious utilities to speed this up.
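
For what it's worth, one variant I've been toying with preallocates a single zero tensor at the batch's maximum size and copies each image into it with slice assignment, which skips the per-image F.pad calls and the final stack. I haven't benchmarked it properly, so I don't know if it actually helps:

import torch

def collate_fn_prealloc(batch):
    # assumes every image is a CHW tensor with the same channel count and dtype
    max_h = max(img.shape[-2] for img in batch)
    max_w = max(img.shape[-1] for img in batch)
    out = torch.zeros(len(batch), batch[0].shape[0], max_h, max_w,
                      dtype=batch[0].dtype)
    for i, img in enumerate(batch):
        h, w = img.shape[-2], img.shape[-1]
        top = (max_h - h) // 2
        left = (max_w - w) // 2
        # center each image inside the zero-padded canvas
        out[i, :, top:top + h, left:left + w] = img
    return out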