This is my collate function for doing what’s in the description. Note that the input images are different sizes, which is the whole point of padding.
import torch

def collate_fn(batch):
    # target size: the largest height and width found in the batch
    max_h = max(img.shape[-2] for img in batch)
    max_w = max(img.shape[-1] for img in batch)
    # pad all images in batch, centering each one in the target size
    for i, img in enumerate(batch):
        pad_left, pad_right, pad_top, pad_bottom = 0, 0, 0, 0
        if (diff_w := max_w - img.shape[-1]) > 0:
            pad_left = diff_w // 2
            pad_right = diff_w - pad_left
        if (diff_h := max_h - img.shape[-2]) > 0:
            pad_top = diff_h // 2
            pad_bottom = diff_h - pad_top
        if any([pad_left, pad_right, pad_top, pad_bottom]):
            # F.pad takes padding for the last dim first: (left, right, top, bottom)
            batch[i] = torch.nn.functional.pad(
                img, (pad_left, pad_right, pad_top, pad_bottom))
    return torch.stack(batch, dim=0)
I'm wondering whether there are any neat tricks or built-in utilities to speed this up.
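For reference, one variant I could imagine is to preallocate the padded output once and copy each image into its centered slice, which avoids the per-image pad call and the final stack. This is only a rough sketch and I have not benchmarked it. The name collate_prealloc is made up for illustration, and it assumes every image is a tensor with the same leading dimensions (e.g. channels), dtype and device, so that only height and width vary.

```python
import torch

def collate_prealloc(batch):
    # Hypothetical alternative to collate_fn: same centered zero padding,
    # but written into a single preallocated output tensor.
    # Assumption: all images share leading dims, dtype and device.
    max_h = max(img.shape[-2] for img in batch)
    max_w = max(img.shape[-1] for img in batch)
    first = batch[0]
    # shape: (batch, *leading dims, max_h, max_w), filled with zeros
    out = first.new_zeros((len(batch), *first.shape[:-2], max_h, max_w))
    for i, img in enumerate(batch):
        h, w = img.shape[-2:]
        top = (max_h - h) // 2    # same split as pad_top above
        left = (max_w - w) // 2   # same split as pad_left above
        out[i, ..., top:top + h, left:left + w] = img
    return out
```

It would be passed to the loader the same way, i.e. DataLoader(dataset, collate_fn=collate_prealloc), and should give the same tensor as the version above as long as those assumptions hold.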