Patch-wise training with high-resolution images

I have high-resolution images of histology slides, about 5000x5000 pixels each. I want to train a CycleGAN (or a similar model) to translate images from one domain to another. Feeding a whole image into the network is not feasible without downscaling it to roughly 400x400 pixels, so I want to train the network on cropped patches instead.

For now I use FiveCrop to get the crops:

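import torch
from PIL import Image
from torchvision import transforms

transform = [
    # Upscale so that the shorter side is 12% larger than 5000 px
    transforms.Resize(int(5000 * 1.12), Image.BICUBIC),
    # Four corner crops plus the center crop, returned as a tuple of 5 images
    transforms.FiveCrop(opt.img_size),
    # Convert each crop to a tensor, normalize to [-1, 1], stack to (5, C, H, W)
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),
    transforms.Lambda(lambda tensors: torch.stack(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tensor) for tensor in tensors]))
]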
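A note on coverage: FiveCrop always returns the same five positions (the four corners and the center), so with 5000x5000 slides and small patches most of each image is never sampled. Below is a minimal sketch of one alternative, cutting patches at random positions. `random_patches` is a hypothetical helper written for illustration, not a torchvision function, and the 512x512 tensor is only a small stand-in for a real slide. For a single crop per image, torchvision's `transforms.RandomCrop` does the same job.

```python
import torch


def random_patches(image, patch_size, n_patches, generator=None):
    """Cut n_patches random square patches from a (C, H, W) tensor.

    Hypothetical helper for illustration; not part of torchvision.
    """
    _, height, width = image.shape
    if patch_size > height or patch_size > width:
        raise ValueError("patch_size exceeds the image size")
    patches = []
    for _ in range(n_patches):
        # Upper bound is exclusive, hence the + 1
        top = int(torch.randint(0, height - patch_size + 1, (1,), generator=generator))
        left = int(torch.randint(0, width - patch_size + 1, (1,), generator=generator))
        patches.append(image[:, top:top + patch_size, left:left + patch_size])
    return torch.stack(patches)  # (n_patches, C, patch_size, patch_size)


slide = torch.rand(3, 512, 512)  # small stand-in for a 5000x5000 slide
patches = random_patches(slide, patch_size=64, n_patches=5)
print(patches.shape)  # torch.Size([5, 3, 64, 64])
```

Because the patch positions change on every call, the network sees different regions of each slide over the epochs.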

In the training loop I then split each batch into its patches and feed them to the network one after another. However, after the second iteration of the inner loop I get CUDA memory allocation errors.

prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        assert batch['from'].size() == batch['to'].size()

        bs, ncrops, channels, height, width = batch['from'].size()

        loss_G = 0

        for patch_from, patch_to in zip(batch['from'].view(-1, bs, channels, height, width),
                                        batch['to'].view(-1, bs, channels, height, width)):
            real_from = Tensor(patch_from.cuda())
            real_to = Tensor(patch_to.cuda())

            # Adversarial ground truths
            valid = Tensor(
                np.ones((real_from.size(0), *D_A.output_shape)))
            fake = Tensor(
                np.zeros((real_from.size(0), *D_A.output_shape)))

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_from), real_from)
            loss_id_B = criterion_identity(G_AB(real_to), real_to)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_to = G_AB(real_from)
            loss_GAN_AB = criterion_GAN(D_B(fake_to), valid)
            fake_from = G_BA(real_to)
            loss_GAN_BA = criterion_GAN(D_A(fake_from), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_to)
            loss_cycle_A = criterion_cycle(recov_A, real_from)
            recov_B = G_AB(fake_from)
            loss_cycle_B = criterion_cycle(recov_B, real_to)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_G + loss_GAN + opt.lambda_cyc * \
                loss_cycle + opt.lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()
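One detail in the inner loop is worth checking separately from the memory problem: `view(-1, bs, channels, height, width)` on a tensor of shape `(bs, ncrops, C, H, W)` only reinterprets the memory layout and does not regroup the data by crop position. The small sketch below uses made-up labels instead of image data to make the difference to `transpose(0, 1)` visible; the sizes are arbitrary.

```python
import torch

bs, ncrops, channels, height, width = 2, 5, 1, 1, 1

# Label every crop with 10 * image_index + crop_index so its origin stays visible.
batch = torch.tensor([[10 * b + c for c in range(ncrops)] for b in range(bs)])
batch = batch.view(bs, ncrops, channels, height, width)

# view() keeps the memory order: the first group holds crops 0 and 1 of
# image 0, not crop 0 of every image.
by_view = batch.view(-1, bs, channels, height, width)
print(by_view[0].flatten().tolist())       # [0, 1]

# transpose() swaps the two leading axes: group k holds crop k of every image.
by_transpose = batch.transpose(0, 1)
print(by_transpose[0].flatten().tolist())  # [0, 10]

# Alternatively, fold the crops into the batch dimension and run one forward pass.
flat = batch.view(-1, channels, height, width)
print(flat.shape)                          # torch.Size([10, 1, 1, 1])
```

Folding the crops into the batch dimension is the pattern shown in the torchvision FiveCrop documentation, but it needs enough memory for `bs * ncrops` patches at once.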

Can someone explain how I can feed patches to the network? Also, are there any other recommendations for training GANs on high-resolution images?
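A likely reason for the growing memory use is that `loss_G` is summed over all patches and `backward()` is only called after the inner loop, so the autograd graph of every patch (all activations of both generators and discriminators) stays alive until then. A minimal sketch of the usual workaround, gradient accumulation with one `backward()` per patch, is shown below. The single convolution and the L1 loss are stand-ins chosen for illustration, not the CycleGAN networks or losses from the post, and everything runs on the CPU; on a GPU each patch would be moved to the device inside the loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for a generator and its loss (assumptions for illustration only).
generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4)

bs, ncrops, channels, height, width = 2, 5, 3, 16, 16
batch_from = torch.rand(bs, ncrops, channels, height, width)

optimizer.zero_grad()   # once per batch, not once per patch
running_loss = 0.0      # plain float for logging, holds no graph

# transpose(0, 1) yields one (bs, C, H, W) tensor per crop position
for patch in batch_from.transpose(0, 1):
    # Divide by ncrops so the accumulated gradient matches the mean over patches
    loss = criterion(generator(patch), patch) / ncrops
    loss.backward()     # gradients add up in .grad; this patch's graph is freed
    running_loss += loss.item()

optimizer.step()        # one update from the accumulated gradients
print(f"mean patch loss: {running_loss:.4f}")
```

With this structure only one patch's graph exists at any time, so peak memory no longer depends on the number of crops per image.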

greets