Hi everyone,
I’m training a CycleGAN and I’m facing a severe slowdown / apparent freeze when I introduce explicit patchification using torch.nn.functional.unfold.
In my dataset I:
- Apply aligned transforms to paired images.
- Convert images to tensors.
- Split them into patches using unfold.
- Return a tensor shaped like (batch, n_patches, C, P, P).
def _aligned_transforms_pipeline(self,
                                 x: Image.Image,
                                 y: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
    self._sanitize_preprocess_params()
    if "resize" in self._preprocess:
        if "patchify" in self._preprocess:
            # Ensure size is a multiple of the patch size
            size = (int(self._load_size / self._patch_size) * self._patch_size,
                    int(self._load_size / self._patch_size) * self._patch_size)
        else:
            # If size is an int, then the shape will be (size * height / width, size)
            size = self._load_size if self._maintain_aspect_ratio else (self._load_size, self._load_size)
        # Resize
        x = functional.resize(img=x, size=size, interpolation=self._interpolation, max_size=None, antialias=True)
        y = functional.resize(img=y, size=size, interpolation=self._interpolation, max_size=None, antialias=True)
    if "crop" in self._preprocess:
        if self._augmentation:
            # Random crop with shared parameters so x and y stay aligned
            i, j, h, w = v1.RandomCrop.get_params(img=x, output_size=(self._crop_size, self._crop_size))
            x = functional.crop(img=x, top=i, left=j, height=h, width=w)
            y = functional.crop(img=y, top=i, left=j, height=h, width=w)
        else:
            x = functional.center_crop(img=x, output_size=[self._crop_size, self._crop_size])
            y = functional.center_crop(img=y, output_size=[self._crop_size, self._crop_size])
    if self._augmentation:
        if "hflip" in self._preprocess:
            # Random horizontal flip
            if random.random() > 0.5:
                x = functional.hflip(img=x)
                y = functional.hflip(img=y)
        if "vflip" in self._preprocess:
            # Random vertical flip
            if random.random() > 0.5:
                x = functional.vflip(img=x)
                y = functional.vflip(img=y)
    # Convert to tensor in [0.0, 1.0]
    x = functional.to_tensor(pic=x)
    y = functional.to_tensor(pic=y)
    if self._normalize:
        x = functional.normalize(tensor=x, mean=self._mean, std=self._std)
        y = functional.normalize(tensor=y, mean=self._mean, std=self._std)
    if "patchify" in self._preprocess:
        # L = number of patches, C = channels, P = patch size
        C = x.size(0)
        p = torch.nn.functional.unfold(input=x, kernel_size=self._patch_size, stride=self._patch_stride)
        x = p.view(C, self._patch_size, self._patch_size, -1).permute(3, 0, 1, 2)  # (L, C, P, P)
        C = y.size(0)
        p = torch.nn.functional.unfold(input=y, kernel_size=self._patch_size, stride=self._patch_stride)
        y = p.view(C, self._patch_size, self._patch_size, -1).permute(3, 0, 1, 2)  # (L, C, P, P)
    return x, y
In the training loop I then collapse patches into the batch dimension:
real_A = real_A.view(-1, C, P, P)
real_B = real_B.view(-1, C, P, P)
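For concreteness, here is a minimal sketch of the shape flow I expect at this point (the dimensions match my setup below; `real_A` here is just a random stand-in for one DataLoader batch):

```python
import torch

B, L, C, P = 1, 4, 3, 512            # batch, patches per image, channels, patch size
real_A = torch.randn(B, L, C, P, P)  # one batch of patchified images from the dataset

# Collapse the patch dimension into the batch dimension before the generator
real_A = real_A.view(-1, C, P, P)
print(real_A.shape)  # torch.Size([4, 3, 512, 512])
```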
The slowdown appears during the generator loss computation. Without patchification everything runs normally.
My setup is roughly the following:
- Batch size = 1.
- Input image size = 1024 × 1024.
- Image is split into 4 patches of size 512 × 512.
- Patches are collapsed into the batch dimension, so the effective input shape to the model is (4, 3, 512, 512).
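A stripped-down, self-contained version of the patchify step (the same unfold-view-permute sequence as in my dataset, on a random 1024 × 1024 tensor; the hard-coded sizes mirror my setup) reproduces those shapes:

```python
import torch
import torch.nn.functional as F

patch_size = patch_stride = 512
x = torch.rand(3, 1024, 1024)  # (C, H, W), as produced by to_tensor

C = x.size(0)
p = F.unfold(input=x, kernel_size=patch_size, stride=patch_stride)   # (C*P*P, L)
x = p.view(C, patch_size, patch_size, -1).permute(3, 0, 1, 2)        # (L, C, P, P)
print(x.shape)            # torch.Size([4, 3, 512, 512])
print(x.is_contiguous())  # False — permute returns a non-contiguous view
```

Note that the permuted result is non-contiguous, which is the memory layout the model actually receives per sample.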
What are the most common causes of this kind of slowdown with unfold-based patchification, and how can I fix it?
Any insight or best practice would be greatly appreciated.
Thanks!
