Severe training slowdown after unfold-based patchification

Hi everyone,

I’m training a CycleGAN, and training slows down severely (or appears to freeze) as soon as I introduce explicit patchification using torch.nn.functional.unfold.

In my dataset I:

  • Apply aligned transforms to paired images.
  • Convert images to tensors.
  • Split them into patches using unfold.
  • Return a tensor of shape (n_patches, C, P, P) per sample, which the DataLoader batches into (batch, n_patches, C, P, P).
    def _aligned_transforms_pipeline(self,
                                     x: Image.Image,
                                     y: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
        self._sanitize_preprocess_params()

        if "resize" in self._preprocess:
            if "patchify" in self._preprocess:
                # Ensure size is a multiple of the patch size
                side = (self._load_size // self._patch_size) * self._patch_size
                size = (side, side)
            else:
                # If size is an int, the smaller edge is resized to it and the aspect ratio is kept,
                # e.g. the output is (size * height / width, size) when height > width
                size = self._load_size if self._maintain_aspect_ratio else (self._load_size, self._load_size)
            # Resize
            x = functional.resize(img=x, size=size, interpolation=self._interpolation, max_size=None, antialias=True)
            y = functional.resize(img=y, size=size, interpolation=self._interpolation, max_size=None, antialias=True)

        if "crop" in self._preprocess:
            if self._augmentation:
                # Random crop
                i, j, h, w = v1.RandomCrop.get_params(img=x, output_size=(self._crop_size, self._crop_size))
                x = functional.crop(img=x, top=i, left=j, height=h, width=w)
                y = functional.crop(img=y, top=i, left=j, height=h, width=w)
            else:
                x = functional.center_crop(img=x, output_size=[self._crop_size, self._crop_size])
                y = functional.center_crop(img=y, output_size=[self._crop_size, self._crop_size])

        if self._augmentation:
            if "hflip" in self._preprocess:
                # Random horizontal flipping
                if random.random() > 0.5:
                    x = functional.hflip(img=x)
                    y = functional.hflip(img=y)
            if "vflip" in self._preprocess:
                # Random vertical flipping
                if random.random() > 0.5:
                    x = functional.vflip(img=x)
                    y = functional.vflip(img=y)

        # Transform to tensor [0.0, 1.0]
        x = functional.to_tensor(pic=x)
        y = functional.to_tensor(pic=y)

        if self._normalize:
            x = functional.normalize(tensor=x, mean=self._mean, std=self._std)
            y = functional.normalize(tensor=y, mean=self._mean, std=self._std)

        if "patchify" in self._preprocess:
            # L = number of patches, C = channels, P = patch size
            C = x.size(0)
            p = torch.nn.functional.unfold(input=x, kernel_size=self._patch_size, stride=self._patch_stride)
            x = p.view(C, self._patch_size, self._patch_size, -1).permute(3, 0, 1, 2)  # (L, C, P, P)

            C = y.size(0)
            p = torch.nn.functional.unfold(input=y, kernel_size=self._patch_size, stride=self._patch_stride)
            y = p.view(C, self._patch_size, self._patch_size, -1).permute(3, 0, 1, 2)  # (L, C, P, P)

        return x, y
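
For reference, here is a minimal standalone sketch of the same unfold-based split, plus an inverse built on torch.nn.functional.fold that makes it easy to check that patches come out in the expected order. The helper names patchify and unpatchify are hypothetical and not part of the code above; the sketch assumes a single (C, H, W) tensor and, for the inverse, non-overlapping patches (stride equal to patch size).

```python
import torch
import torch.nn.functional as F


def patchify(img: torch.Tensor, patch_size: int, stride: int) -> torch.Tensor:
    """Split a (C, H, W) image into patches of shape (L, C, P, P) with F.unfold."""
    C = img.size(0)
    # Add a batch dim so this also works on PyTorch versions that only accept
    # 4-D input. Output: (1, C * P * P, L).
    cols = F.unfold(img.unsqueeze(0), kernel_size=patch_size, stride=stride)
    # Rows of `cols` are ordered channel first, then kernel row, then kernel column.
    return cols.view(C, patch_size, patch_size, -1).permute(3, 0, 1, 2)


def unpatchify(patches: torch.Tensor, out_hw: tuple[int, int], stride: int) -> torch.Tensor:
    """Inverse of patchify, exact only for non-overlapping patches."""
    L, C, P, _ = patches.shape
    # Back to (1, C * P * P, L); reshape copies because the permuted tensor is non-contiguous.
    cols = patches.permute(1, 2, 3, 0).reshape(1, C * P * P, L)
    # fold sums values where patches overlap, hence the non-overlap assumption.
    return F.fold(cols, output_size=out_hw, kernel_size=P, stride=stride).squeeze(0)


x = torch.arange(3 * 8 * 8, dtype=torch.float32).view(3, 8, 8)
p = patchify(x, patch_size=4, stride=4)
print(p.shape)                                          # torch.Size([4, 3, 4, 4])
print(torch.equal(p[1], x[:, :4, 4:]))                  # True: patch 1 is the top-right block
print(torch.equal(unpatchify(p, (8, 8), stride=4), x))  # True
```

Patches are numbered row by row over the patch grid, so for a 1024 × 1024 image with 512 × 512 patches, index 1 is the top-right quadrant. Round-tripping a real sample through both helpers is a cheap way to rule out the reshaping itself as the problem.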

In the training loop I then collapse patches into the batch dimension:

real_A = real_A.view(-1, C, P, P)
real_B = real_B.view(-1, C, P, P)
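
A small sketch of what this collapse does, using toy sizes (the values of B, L, C and P are arbitrary assumptions) and torch.stack as a stand-in for the DataLoader's default collation. flatten(0, 1) is shown as an equivalent alternative to view(-1, C, P, P).

```python
import torch

B, L, C, P = 2, 4, 3, 8  # assumed toy sizes: batch, patches per image, channels, patch size

# Stand-in for the DataLoader's default collation: stack (L, C, P, P) samples.
samples = [torch.randn(L, C, P, P) for _ in range(B)]
batch = torch.stack(samples)  # (B, L, C, P, P)

# Collapse batch and patch dims. Same result as batch.view(-1, C, P, P) here, but
# flatten falls back to a copy instead of raising if the input is non-contiguous.
flat = batch.flatten(0, 1)  # (B * L, C, P, P)

# Patch l of image b sits at row b * L + l, so the operation is reversible.
restored = flat.view(B, L, C, P, P)
print(flat.shape)                                   # torch.Size([8, 3, 8, 8])
print(torch.equal(flat[1 * L + 2], samples[1][2]))  # True
```

Because the mapping is fixed, generator outputs can be regrouped per image with view(B, L, C, P, P) whenever that is needed.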

The slowdown appears during the generator loss computation. Without patchification everything runs normally.

My setup is roughly the following:

  • Batch size = 1.
  • Input image size = 1024 × 1024.
  • Image is split into 4 patches of size 512 × 512.
  • Patches are collapsed into the batch dimension, so the effective input to the model has shape (4, 3, 512, 512).
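
The patch count follows from the output-length formula in the torch.nn.Unfold documentation; with no padding and dilation 1 it reduces to the sketch below (n_patches is a hypothetical helper). It also shows that four 512 × 512 patches contain exactly as many pixels as one 1024 × 1024 image, so for a fully convolutional generator the activation memory per step should be roughly what a single full-resolution image would need.

```python
def n_patches(height: int, width: int, patch: int, stride: int) -> int:
    """Patches produced by F.unfold on a (height, width) image, no padding, dilation 1."""
    rows = (height - patch) // stride + 1
    cols = (width - patch) // stride + 1
    return rows * cols


print(n_patches(1024, 1024, 512, 512))  # 4 (a 2 x 2 grid)
print(n_patches(1024, 1024, 512, 256))  # 9 (an overlapping 3 x 3 grid)
print(4 * 512 * 512 == 1024 * 1024)     # True: same pixel count as the full image
```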

What are the most common causes of this kind of slowdown when using unfold-based patchification? How can I solve this problem?

Any insight or best practice would be greatly appreciated.

Thanks!

Materializing tensors with broadcasted dimensions is often the root cause of unexpected slowdowns. Check whether this is the case here; if so, you should also see a large increase in memory usage.
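
A quick way to see how large a broadcast result would be, without allocating it, is torch.broadcast_shapes. The shapes below are hypothetical, chosen to mimic a target that still carries an extra patch dimension; check_same_shape is an illustrative guard, not a PyTorch function.

```python
import torch

# Hypothetical shapes: a generator output vs. a target that kept an extra patch dim.
out_shape = (4, 3, 512, 512)
target_shape = (4, 1, 3, 512, 512)

# Computes the broadcast result shape without allocating any memory.
print(torch.broadcast_shapes(out_shape, target_shape))
# torch.Size([4, 4, 3, 512, 512]): 4x the elements of either operand


def check_same_shape(a: torch.Tensor, b: torch.Tensor) -> None:
    """Illustrative guard: raise instead of silently broadcasting."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
```

F.l1_loss and F.mse_loss also emit a UserWarning when input and target sizes differ, so that warning is worth watching for in, e.g., an L1 cycle-consistency loss.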

Yes, I can confirm that the memory usage saturates the full 24 GB of my NVIDIA RTX 4090. How can I address this issue? As shown in the code above, my goal is to split an image into multiple patches.

Try to narrow down which line of code increases the memory usage through broadcasting, and check whether that is really what the code should do, since such broadcasting is sometimes introduced by mistake.
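
One way to narrow it down is to log allocated and peak CUDA memory between steps. log_mem below is a hypothetical helper built on torch.cuda.memory_allocated and torch.cuda.max_memory_allocated; where to call it (for example before and after self._forward_pass and after each loss term) is up to you.

```python
from typing import Optional

import torch


def log_mem(tag: str, device: Optional[int] = None) -> int:
    """Print current and peak CUDA memory; return bytes currently allocated (0 without CUDA)."""
    if not torch.cuda.is_available():
        print(f"[{tag}] CUDA not available")
        return 0
    allocated = torch.cuda.memory_allocated(device)
    peak = torch.cuda.max_memory_allocated(device)
    print(f"[{tag}] allocated={allocated / 2**20:.1f} MiB, peak={peak / 2**20:.1f} MiB")
    return allocated


# Optionally reset the peak counter at the start of each iteration:
# torch.cuda.reset_peak_memory_stats()
log_mem("start")
```

A jump that lines up with a single loss term, rather than with the generator calls, would point to an unintended broadcast in that term.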

Up to this point everything looks fine in terms of memory usage:

real_A = inputs.to(device=self._device)
real_B = targets.to(device=self._device)

if "patchify" in self._opt.preprocess:
    # Collapse "batch_size * n_patches" into the batch dimension
    real_A = real_A.view(-1, real_A.size(2), real_A.size(3), real_A.size(4))
    real_B = real_B.view(-1, real_B.size(2), real_B.size(3), real_B.size(4))

inputs and targets start with shape (1, 4, 3, 512, 512), and after the view they become (4, 3, 512, 512). Memory usage is stable and as expected up to here.

The saturation happens after the forward pass:

fake_B, rec_A, fake_A, rec_B = self._forward_pass(real_A, real_B)

where:

  • fake_B: (4, 3, 512, 512)
  • rec_A: (4, 3, 512, 512)
  • fake_A: (4, 3, 512, 512)
  • rec_B: (4, 3, 512, 512)

So at this stage there should be six tensors of shape (4, 3, 512, 512) in memory in total (real_A, real_B, and the 4 outputs).
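
As a sanity check on that count: the six tensors themselves are small. In float32 each (4, 3, 512, 512) tensor is 12 MiB, so all six take about 72 MiB. The bulk of training memory is normally the intermediate activations that autograd saves inside the generator calls for the backward pass, and those are not among the six tensors.

```python
from math import prod


def tensor_mib(shape: tuple[int, ...], bytes_per_elem: int = 4) -> float:
    """Size of a dense tensor in MiB (float32 by default)."""
    return prod(shape) * bytes_per_elem / 2**20


per_tensor = tensor_mib((4, 3, 512, 512))
print(per_tensor)      # 12.0
print(6 * per_tensor)  # 72.0
```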

The forward pass is simply:

def _forward_pass(self, real_A, real_B):
    fake_B = self._netG_AB(real_A)
    rec_A  = self._netG_BA(fake_B)
    fake_A = self._netG_BA(real_B)
    rec_B  = self._netG_AB(fake_A)
    return fake_B, rec_A, fake_A, rec_B
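
Each _forward_pass call runs a generator four times, and the autograd graphs of all four calls stay alive until the generator backward pass frees them. The sketch below gives a rough sense of how much a single call keeps around, using torch.autograd.graph.saved_tensors_hooks to tally the tensors saved for backward. toy_gen is a hypothetical stand-in, not _netG_AB; running saved_activation_mib on the real generators with one (4, 3, 512, 512) batch should give an estimate of their actual footprint.

```python
import torch
import torch.nn as nn


def saved_activation_mib(model: nn.Module, x: torch.Tensor) -> tuple[float, float]:
    """Roughly tally what autograd keeps for backward during model(x).

    Returns (saved_mib, output_mib). A tensor saved by several ops (and conv
    weights) is counted each time, so saved_mib is an upper-bound estimate.
    """
    saved_bytes = 0

    def pack(t: torch.Tensor) -> torch.Tensor:
        nonlocal saved_bytes
        saved_bytes += t.numel() * t.element_size()
        return t  # store the tensor unchanged; we only count it

    def unpack(t: torch.Tensor) -> torch.Tensor:
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = model(x)
    out_bytes = out.numel() * out.element_size()
    return saved_bytes / 2**20, out_bytes / 2**20


# Hypothetical stand-in for a generator; substitute the real network to measure it.
toy_gen = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, padding=3),
    nn.InstanceNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 3, kernel_size=7, padding=3),
    nn.Tanh(),
)
saved, out = saved_activation_mib(toy_gen, torch.randn(4, 3, 64, 64))
print(f"saved for backward: ~{saved:.1f} MiB, output: {out:.2f} MiB")
```

Comparing the saved figure for the real generators, times the four calls, with the 24 GB available should indicate whether the workload simply does not fit.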

After each forward pass, GPU memory usage keeps increasing (see attached image).
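
Memory that grows from one iteration to the next (rather than saturating once) is a different symptom from a single oversized step. One common cause listed in the PyTorch FAQ is accumulating autograd history across the training loop, for example summing loss tensors for logging. The code shown here does not confirm that this applies; the sketch below, with a hypothetical toy parameter w, only illustrates the pattern and the usual fix.

```python
import torch

torch.manual_seed(0)
w = torch.randn(16, requires_grad=True)  # hypothetical toy parameter
opt = torch.optim.SGD([w], lr=0.1)

running_bad = torch.zeros(())  # anti-pattern: a tensor that chains every step's graph
running_ok = 0.0               # fix: a plain Python float

for step in range(3):
    opt.zero_grad()
    loss = (w * torch.randn(16)).pow(2).mean()
    # Leaky: running_bad keeps this iteration's autograd graph reachable. If such a
    # graph is never backpropagated (or retain_graph=True is used), its saved
    # tensors stay in memory as well.
    running_bad = running_bad + loss
    # Safe: .item() copies the value out and keeps no reference to the graph.
    running_ok += loss.item()
    loss.backward()
    opt.step()

print(running_bad.grad_fn is not None)  # True: history is still attached
print(type(running_ok).__name__)        # float
```

The same applies to anything stored across iterations, such as images kept for visualization: storing .detach() copies avoids holding on to the graph.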

At this point it’s not clear to me whether this is simply insufficient VRAM (24 GB) for the workload, or whether the forward pass is doing something wrong.

Let me know if it would help to share more parts of the code.

Any suggestion would be appreciated.

@ptrblck Do you have any suggestions?