I have high-resolution histology slide images of about 5000x5000 pixels. I want to train a CycleGAN (or a similar model) to translate images from one domain to another. Obviously, feeding the whole image into the network is impossible without downscaling it to roughly 400x400 pixels, so I want to train the network on cropped patches instead.
For now I want to use `FiveCrop` to get the crops:
```python
from PIL import Image
import torch
from torchvision import transforms

transform = [
    transforms.Resize(int(5000 * 1.12), Image.BICUBIC),
    transforms.FiveCrop(opt.img_size),
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.ToTensor()(crop) for crop in crops])),
    transforms.Lambda(lambda tensors: torch.stack(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tensor)
         for tensor in tensors])),
]
```
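As I understand it, the DataLoader should then yield batches of shape `(bs, 5, C, H, W)`, with the five crops stacked in an extra dimension. A minimal sanity check with dummy tensors (the batch size and the small patch size here are just placeholders, not my real settings):

```python
import torch

bs, ncrops, channels, height, width = 2, 5, 3, 32, 32

# Dummy batch shaped like the DataLoader output after FiveCrop + stacking
batch = torch.randn(bs, ncrops, channels, height, width)

# Flattening the crop dimension into the batch dimension gives one patch per row
patches = batch.view(-1, channels, height, width)
print(patches.shape)  # torch.Size([10, 3, 32, 32])
```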
Then, at the batch level, I split each batch into patches and feed those into the network one at a time. However, after the second iteration of the inner loop I get CUDA memory-allocation errors.
```python
prev_time = time.time()
for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        assert batch['from'].size() == batch['to'].size()
        bs, ncrops, channels, height, width = batch['from'].size()
        loss_G = 0
        for patch_from, patch_to in zip(
                batch['from'].view(-1, bs, channels, height, width),
                batch['to'].view(-1, bs, channels, height, width)):
            real_from = Tensor(patch_from.cuda())
            real_to = Tensor(patch_to.cuda())

            # Adversarial ground truths
            valid = Tensor(np.ones((real_from.size(0), *D_A.output_shape)))
            fake = Tensor(np.zeros((real_from.size(0), *D_A.output_shape)))

            # ------------------
            #  Train Generators
            # ------------------
            G_AB.train()
            G_BA.train()
            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_from), real_from)
            loss_id_B = criterion_identity(G_AB(real_to), real_to)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_to = G_AB(real_from)
            loss_GAN_AB = criterion_GAN(D_B(fake_to), valid)
            fake_from = G_BA(real_to)
            loss_GAN_BA = criterion_GAN(D_A(fake_from), valid)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_to)
            loss_cycle_A = criterion_cycle(recov_A, real_from)
            recov_B = G_AB(fake_from)
            loss_cycle_B = criterion_cycle(recov_B, real_to)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = (loss_G + loss_GAN
                      + opt.lambda_cyc * loss_cycle
                      + opt.lambda_id * loss_identity)

            loss_G.backward()
            optimizer_G.step()
```
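One thing I suspect: the accumulation `loss_G = loss_G + ...` keeps every previous patch's computation graph alive across inner-loop iterations, so memory grows with each patch, and calling `backward()` on the accumulated sum each iteration would try to backpropagate through already-freed graphs. A minimal sketch of the pattern I think would avoid this, with a toy linear model standing in for the generators (all names here are placeholders for illustration, not my actual code):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # toy stand-in for the generator
optim = torch.optim.SGD(model.parameters(), lr=0.01)

patches = torch.randn(5, 4, 8)  # 5 patches, each a mini-batch of 4 samples
for patch in patches:
    optim.zero_grad()
    loss = model(patch).pow(2).mean()  # fresh scalar loss per patch, no accumulation
    loss.backward()                    # graph for this patch is freed immediately
    optim.step()
```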
Can someone explain how to correctly feed patches to the network? And does anyone have other recommendations for training GANs on high-resolution images?