I’m training an ACGAN on CIFAR-10. After only 50 steps, the network classifies the fake images with 100% accuracy, which suggests that something is going wrong. However, I am not sure what my mistake is.
training log:
epoch: 1/100 batch: 50/500 G loss: 0.8187 D loss: 4.1763 Loss cls fake: 0.0740 Loss cls real: 1.5676 fake acc: 100.0% real acc: 40.0%
epoch: 1/100 batch: 100/500 G loss: 0.9943 D loss: 3.8245 Loss cls fake: 0.1135 Loss cls real: 1.4479 fake acc: 100.0% real acc: 45.0%
training loop:
# AC-GAN training loop: for every batch, run one discriminator update on a
# real batch and a generated batch, then one generator update.
for epoch in range(num_epochs):
    for batch_idx, (image, target) in enumerate(train_loader):
        images = image.to(device)
        target = target.to(device)
        current_batchSize = images.size()[0]

        # Adversarial targets (1 = real, 0 = fake). Created directly on
        # `device` so no host tensor has to be allocated and copied per batch.
        realLabel = torch.ones(current_batchSize, device=device)
        fakeLabel = torch.zeros(current_batchSize, device=device)

        ###########
        # TRAIN D #
        ###########
        optimizerD.zero_grad()

        # On real data
        predictR, predictRLabel = D(images)
        loss_real_adv = criterion_adv(predictR, realLabel)
        loss_real_aux = criterion_aux(predictRLabel, target)
        real_cls_acc = compute_cls_acc(predictRLabel, target)
        # Monitoring value only: detached so it does not reference the graph.
        real_score = predictR.detach()

        # On fake data: the generator input is noise concatenated with a
        # one-hot encoding of the class labels taken from the real batch.
        latent_value = torch.randn(current_batchSize, latent_size, device=device)
        cls_one_hot = torch.zeros(current_batchSize, n_classes, device=device)
        cls_one_hot[torch.arange(current_batchSize), target] = 1.0
        latent = torch.cat((latent_value, cls_one_hot), dim=1)
        fake_images = G(latent)

        # detach(): the discriminator step must not backpropagate into G.
        predictF, predictFLabel = D(fake_images.detach())
        loss_fake_adv = criterion_adv(predictF, fakeLabel)
        loss_fake_aux = criterion_aux(predictFLabel, target)
        fake_cls_acc = compute_cls_acc(predictFLabel, target)
        # Monitoring value only: detached so it does not reference the graph.
        fake_score = predictF.detach()

        loss_adv = loss_real_adv + loss_fake_adv
        # The auxiliary (class) loss on real images is weighted 1.8 and the
        # one on generated images 0.2.
        loss_aux = 1.8 * loss_real_aux + 0.2 * loss_fake_aux
        lossD = loss_adv + loss_aux
        lossD.backward()
        optimizerD.step()

        ###########
        # TRAIN G #
        ###########
        optimizerG.zero_grad()

        # Fresh forward pass through the just-updated D, this time without
        # detach() so gradients reach G.
        predictG, predictLabel = D(fake_images)
        # G is rewarded when D labels its samples as real ...
        lossG_adv = criterion_adv(predictG, realLabel)
        # ... and when D assigns them the class G was conditioned on.
        # NOTE(review): because G is optimized on this same class loss, a very
        # high fake-class accuracy early in training is plausible and not by
        # itself proof of a bug -- confirm by inspecting generated samples.
        lossG_aux = criterion_aux(predictLabel, target)
        lossG = lossG_adv + lossG_aux
        lossG.backward()
        optimizerG.step()

        # Report progress every 50 batches; losses are converted to Python
        # floats with .item() so all four are logged the same way.
        if (batch_idx + 1) % 50 == 0:
            print('epoch: {}/{} batch: {}/{} G loss: {:.4f} D loss: {:.4f} Loss cls fake: {:.4f} Loss cls real: {:.4f} fake acc: {}% real acc: {}%'.format(
                epoch + 1, num_epochs, batch_idx + 1, total_step,
                lossG.item(), lossD.item(),
                loss_fake_aux.item(), loss_real_aux.item(),
                fake_cls_acc, real_cls_acc))
Thankful for any help!