I have working code based on img2img-turbo/src/train_cyclegan_turbo.py at main · GaParmar/img2img-turbo · GitHub. I tried adding a few lines, marked with # ADDED BY ME / # MODIFIED BY ME below, and after that I got the error. I know why I get the error, but I couldn't figure out how to solve it: if I use retain_graph=True I get out of memory, as expected. Is there any workaround? I appreciate any suggestions. The full training loop that produces the error is further below.
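To illustrate what I think is happening, here is a minimal sketch of the same failure mode in isolation (the module and tensor names here are made up and only stand in for my setup):

import torch

init_net = torch.nn.Linear(4, 4)   # stands in for init_pred_net
gen_net = torch.nn.Linear(4, 4)    # stands in for the unet/vae generator path

x = torch.randn(2, 4)
img_a = init_net(x)                # img_a carries init_net's autograd graph

loss1 = (gen_net(img_a) ** 2).mean()
loss1.backward()                   # frees the graph, including init_net's part

loss2 = (gen_net(img_a) - 1).abs().mean()
loss2.backward()                   # RuntimeError: Trying to backward through
                                   # the graph a second time

In my training loop, img_a = init_pred_net(burst_img) plays the role of img_a here: several objectives each call accelerator.backward() on losses that all reach back through it.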
for epoch in range(first_epoch, args.max_train_epochs):
    for step, batch in enumerate(train_dataloader):
        l_acc = [unet, net_disc_a, net_disc_b, vae_enc, vae_dec, init_pred_net]
        with accelerator.accumulate(*l_acc):
            burst_img = batch["pixel_values_src"].to(dtype=weight_dtype)
            print("burst_img shape", burst_img.shape, flush=True)
            # ADDED BY ME
            img_a = init_pred_net(burst_img)  # initial prediction
            # ADDED BY ME
            del burst_img
            print("img_pred shape after init pred", img_a.shape, flush=True)
            img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)
            print("img_b shape", img_b.shape, flush=True)
            bsz = img_a.shape[0]
            fixed_a2b_emb = fixed_a2b_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
            fixed_b2a_emb = fixed_b2a_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
            timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * bsz, device=img_a.device).long()
"""
Cycle Objective
"""
# A -> fake B -> rec A
cyc_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
cyc_rec_a = CycleGAN_Turbo.forward_with_networks(cyc_fake_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
loss_cycle_a = crit_cycle(cyc_rec_a, img_a) * args.lambda_cycle
loss_cycle_a += net_lpips(cyc_rec_a, img_a).mean() * args.lambda_cycle_lpips
# B -> fake A -> rec B
cyc_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
cyc_rec_b = CycleGAN_Turbo.forward_with_networks(cyc_fake_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
loss_cycle_b = crit_cycle(cyc_rec_b, img_b) * args.lambda_cycle
loss_cycle_b += net_lpips(cyc_rec_b, img_b).mean() * args.lambda_cycle_lpips
            # ADDED BY ME
            loss_init_pred = crit_init_pred(img_a, img_b)
            # MODIFIED BY ME
            accelerator.backward(loss_cycle_a + loss_cycle_b + loss_init_pred, retain_graph=False)
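            # This first backward frees the autograd graph, including the part
            # produced by init_pred_net, but img_a is reused below undetached.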
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
            optimizer_gen.step()
            lr_scheduler_gen.step()
            optimizer_gen.zero_grad()
"""
Generator Objective (GAN) for task a->b and b->a (fake inputs)
"""
fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
loss_gan_a = net_disc_a(fake_b, for_G=True).mean() * args.lambda_gan
loss_gan_b = net_disc_b(fake_a, for_G=True).mean() * args.lambda_gan
accelerator.backward(loss_gan_a + loss_gan_b, retain_graph=False)
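            # ^ As far as I can tell, the error is raised here: fake_b depends
            #   on img_a, whose producing graph was freed by the first backward.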
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
            optimizer_gen.step()
            lr_scheduler_gen.step()
            optimizer_gen.zero_grad()
            optimizer_disc.zero_grad()
"""
Identity Objective
"""
idt_a = CycleGAN_Turbo.forward_with_networks(img_b, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
loss_idt_a = crit_idt(idt_a, img_b) * args.lambda_idt
loss_idt_a += net_lpips(idt_a, img_b).mean() * args.lambda_idt_lpips
idt_b = CycleGAN_Turbo.forward_with_networks(img_a, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
loss_idt_b = crit_idt(idt_b, img_a) * args.lambda_idt
loss_idt_b += net_lpips(idt_b, img_a).mean() * args.lambda_idt_lpips
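            # idt_b also depends on img_a, so the backward below would hit the
            # same freed graph even if the GAN step above were fixed.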
            loss_g_idt = loss_idt_a + loss_idt_b
            accelerator.backward(loss_g_idt, retain_graph=False)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
            optimizer_gen.step()
            lr_scheduler_gen.step()
            optimizer_gen.zero_grad()
"""
Discriminator for task a->b and b->a (fake inputs)
"""
loss_D_A_fake = net_disc_a(fake_b.detach(), for_real=False).mean() * args.lambda_gan
loss_D_B_fake = net_disc_b(fake_a.detach(), for_real=False).mean() * args.lambda_gan
loss_D_fake = (loss_D_A_fake + loss_D_B_fake) * 0.5
accelerator.backward(loss_D_fake, retain_graph=False)
if accelerator.sync_gradients:
params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer_disc.step()
lr_scheduler_disc.step()
optimizer_disc.zero_grad()
"""
Discriminator for task a->b and b->a (real inputs)
"""
loss_D_A_real = net_disc_a(img_b, for_real=True).mean() * args.lambda_gan
loss_D_B_real = net_disc_b(img_a, for_real=True).mean() * args.lambda_gan
loss_D_real = (loss_D_A_real + loss_D_B_real) * 0.5
accelerator.backward(loss_D_real, retain_graph=False)
if accelerator.sync_gradients:
params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer_disc.step()
lr_scheduler_disc.step()
optimizer_disc.zero_grad()
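One workaround I can think of is to detach img_a right after the first backward, so the later objectives reuse its values without its graph. In terms of the toy sketch above:

loss1.backward()
img_a = img_a.detach()             # reuse the values, drop init_net's graph
loss2 = (gen_net(img_a) - 1).abs().mean()
loss2.backward()                   # fine: gradients flow into gen_net only

But then init_pred_net would only receive gradients from the cycle and init-pred losses, not from the GAN and identity objectives. The alternatives I see are recomputing img_a = init_pred_net(burst_img) before each objective (which means keeping burst_img around and paying for extra forward passes), or summing all the generator losses into a single backward. Is one of these the right pattern here, or is there a better way?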