I have high-resolution histology slide images of about 5000x5000 pixels. I want to train a CycleGAN (or a similar model) to translate images from one domain to another. Obviously, feeding the whole image into the network is impossible without downscaling it to roughly 400x400 pixels, so I want to train the network on cropped patches instead.
For now I want to use `FiveCrop` to get the crops:
```python
from PIL import Image
import torch
from torchvision import transforms

transform = [
    transforms.Resize(int(5000 * 1.12), Image.BICUBIC),
    transforms.FiveCrop(opt.img_size),
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),
    transforms.Lambda(lambda tensors: torch.stack(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tensor)
         for tensor in tensors])),
]
```
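As I understand it, the DataLoader should then yield batches of shape `(bs, 5, C, H, W)`, with the five crops stacked in an extra dimension. A minimal sanity check with dummy tensors (the batch size and the small patch size here are just placeholders, not my real settings):

```python
import torch

bs, ncrops, channels, height, width = 2, 5, 3, 32, 32

# Dummy batch shaped like the DataLoader output after FiveCrop + stacking
batch = torch.randn(bs, ncrops, channels, height, width)

# Flattening the crop dimension into the batch dimension gives one patch per row
patches = batch.view(-1, channels, height, width)
print(patches.shape)  # torch.Size([10, 3, 32, 32])
```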
Then, at the batch level, I split each batch into patches and feed those into the network one at a time. However, after the second iteration of the inner loop I get CUDA memory-allocation errors.
```python
prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        assert batch['from'].size() == batch['to'].size()
        bs, ncrops, channels, height, width = batch['from'].size()
        loss_G = 0
        for patch_from, patch_to in zip(
                batch['from'].view(-1, bs, channels, height, width),
                batch['to'].view(-1, bs, channels, height, width)):
            real_from = Tensor(patch_from.cuda())
            real_to = Tensor(patch_to.cuda())

            # Adversarial ground truths
            valid = Tensor(np.ones((real_from.size(0), *D_A.output_shape)))
            fake = Tensor(np.zeros((real_from.size(0), *D_A.output_shape)))

            # ------------------
            #  Train Generators
            # ------------------
            G_AB.train()
            G_BA.train()
            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_from), real_from)
            loss_id_B = criterion_identity(G_AB(real_to), real_to)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_to = G_AB(real_from)
            loss_GAN_AB = criterion_GAN(D_B(fake_to), valid)
            fake_from = G_BA(real_to)
            loss_GAN_BA = criterion_GAN(D_A(fake_from), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_to)
            loss_cycle_A = criterion_cycle(recov_A, real_from)
            recov_B = G_AB(fake_from)
            loss_cycle_B = criterion_cycle(recov_B, real_to)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = (loss_G + loss_GAN
                      + opt.lambda_cyc * loss_cycle
                      + opt.lambda_id * loss_identity)

            loss_G.backward()
            optimizer_G.step()
```
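One thing I suspect: the accumulation `loss_G = loss_G + ...` keeps every previous patch's computation graph alive across inner-loop iterations, so memory grows with each patch, and calling `backward()` on the accumulated sum each iteration would try to backpropagate through already-freed graphs. A minimal sketch of the pattern I think would avoid this, with a toy linear model standing in for the generators (all names here are placeholders for illustration, not my actual code):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # toy stand-in for the generator
optim = torch.optim.SGD(model.parameters(), lr=0.01)

patches = torch.randn(5, 4, 8)  # 5 patches, each a mini-batch of 4 samples
for patch in patches:
    optim.zero_grad()
    loss = model(patch).pow(2).mean()  # fresh scalar loss per patch, no accumulation
    loss.backward()                    # graph for this patch is freed immediately
    optim.step()
```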
Can someone explain how to correctly feed patches to the network? And does anyone have other recommendations for training GANs on high-resolution images?