I guess I'm pretty much a noob. Here is my training loop body:
for i, (lr_image, hr_image) in enumerate(train_bar):
    start = time.time()
    batch_size = lr_image.shape[0]
    if torch.cuda.is_available():
        lr_image = lr_image.cuda()
        hr_image = hr_image.cuda()

    d_lr = adjust_learning_rate(
        l.optimizerD, epoch, i, l.train_loader.size, config.d_lr
    )
    g_lr = adjust_learning_rate(
        l.optimizerG, epoch, i, l.train_loader.size, config.g_lr
    )

    if config.prof and i > 10:
        break

    ############################
    # (1) Update D network
    ############################
    l.netD.zero_grad()
    fake_img = l.netG(lr_image)
    real_out = l.netD(hr_image).mean()
    fake_out = l.netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True)
    l.optimizerD.step()

    ############################
    # (2) Update G network
    ############################
    l.netG.zero_grad()
    g_loss = l.generator_loss(fake_out, fake_img, hr_image)
    g_loss.backward()
    l.optimizerG.step()
I got it from https://github.com/leftthomas/SRGAN/blob/master/train.py. I realize this might be wrong because fake_img isn't detached before being passed to the discriminator. According to this post (I think): "When training GAN why do we not need to zero_grad discriminator?", I might be doing the wrong thing. Also, the discriminator's loss function is a little wacky, but it makes sense to me in a way: if real_out == 1 the penalty is low, and if fake_out == 1 the penalty is high.
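If I understand the detach idea correctly, the D update would look something like the sketch below. To keep it self-contained I'm using tiny stand-in linear layers here (netG, netD, optD are placeholders I made up, not the real SRGAN models):

```python
import torch
import torch.nn as nn

# stand-ins for the real generator / discriminator (hypothetical, for illustration)
netG = nn.Linear(4, 4)
netD = nn.Linear(4, 1)
optD = torch.optim.SGD(netD.parameters(), lr=0.1)

lr_image = torch.randn(2, 4)
hr_image = torch.randn(2, 4)

netD.zero_grad()
fake_img = netG(lr_image)
real_out = netD(hr_image).mean()
# .detach() cuts the graph at fake_img, so D's backward pass
# never reaches (or stores gradients for) G's parameters
fake_out = netD(fake_img.detach()).mean()
d_loss = 1 - real_out + fake_out
d_loss.backward()  # no retain_graph=True needed anymore
optD.step()

# G accumulated no gradient from the D step
assert all(p.grad is None for p in netG.parameters())
```

For the G step you'd then re-run netD(fake_img) *without* the detach, so the adversarial term can flow back into the generator. Is that the right way to think about it?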