I’m making a GAN. The generator G takes in an MNIST image x and outputs an adversarial perturbation G(x), such that a target network F classifies the adversarial example x + G(x) incorrectly (i.e., G fools the target network).
The discriminator D takes in real images x and adversarial examples x + G(x); it should output 1 if an image is real (x) and 0 if it is fake (x + G(x)).
G’s parameters should minimize the loss: G_Loss = log(1 - D(x + G(x)))
D’s parameters should minimize the loss: D_Loss = log(1 - D(x)) + log(D(x + G(x)))
G’s parameters should also maximize the loss: F_Loss = CE_Loss(F(x + G(x)), t), where t are the target class labels.
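For concreteness, here is a minimal sketch of those three objectives written as PyTorch expressions. The stand-in networks and tensors below are hypothetical (they only exist so the sketch runs on its own); the one real assumption is that D ends in a sigmoid, so its outputs lie in (0, 1):

import torch
import torch.nn as nn

# hypothetical stand-ins so the sketch is runnable; shapes mirror my setup
x = torch.rand(32, 784)                              # batch of flattened images
G = nn.Sequential(nn.Linear(784, 784), nn.Tanh())    # generator stand-in
D = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())   # discriminator stand-in
F_net = nn.Linear(784, 10)                           # target-classifier stand-in
t = torch.randint(0, 10, (32,))                      # target class labels

adv_ex = x + G(x)                                                 # adversarial examples
G_Loss = torch.mean(torch.log(1. - D(adv_ex)))                    # G minimizes this
D_Loss = torch.mean(torch.log(1. - D(x)) + torch.log(D(adv_ex)))  # D minimizes this
F_Loss = nn.CrossEntropyLoss()(F_net(adv_ex), t)                  # G should maximize this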
Here is the code:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
img_dim = 28 * 28 * 1
batch_size = 32
num_epochs = 5 #50
# initialize the models
D = Discriminator(img_dim).to(device)
G = Generator(img_dim, img_dim).to(device)
F = TargetA().to(device)
PATH = "./trained_models/A/Trained_model_A"
F.load_state_dict(torch.load(PATH))
F.eval()
# set up the MNIST data (renamed the Compose so it doesn't shadow the transforms module)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
dataset = datasets.MNIST(root="./data", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
# set up the optimizers, the loss, and the TensorBoard writers
opt_disc = optim.Adam(D.parameters(), lr=lr)
opt_gen = optim.Adam(G.parameters(), lr=lr)
CE_loss = nn.CrossEntropyLoss()
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")  # log dirs are arbitrary
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0  # global step counter for TensorBoard
for epoch in range(num_epochs):
    for batch_idx, (real, labels) in enumerate(loader):
        # keep a fixed input batch to display the generator's output
        if batch_idx == 0 and epoch == 0:
            fixed_input = real.view(-1, 784).to(device)
        adv_ex = real.clone().reshape(-1, 784).to(device)  # [32, 784] copy of the batch, flattened
        real = real.view(-1, 784).to(device)               # [32, 784] real batch, flattened
        labels = labels.to(device)                         # [32] labels for the batch
        # perturb every image in the batch with the generator's output
        # (G is applied to the whole [32, 784] batch at once instead of per image)
        adv_ex = adv_ex + G(adv_ex)
        # Train Generator: min log(1 - D(x + G(x))) <-> max F_Loss
        D_out = D(adv_ex)  # discriminator decides whether adv_ex is real or fake
        G_Loss = torch.mean(torch.log(1. - D_out))  # loss for G's desired D prediction
        f_pred = F(adv_ex.reshape(-1, 1, 28, 28))   # target's logits, size [32, 10]
        F_Loss = CE_loss(f_pred, labels)  # FIGURE OUT HOW TO MAXIMIZE THIS
        opt_gen.zero_grad()
        loss_G_Final = G_Loss + F_Loss
        loss_G_Final.backward()
        opt_gen.step()
        # Train Discriminator: min log(1 - D(x)) + log(D(x + G(x)))
        disc_real = D(real).view(-1)             # D's prediction on the real batch
        disc_fake = D(adv_ex.detach()).view(-1)  # detach so D's update doesn't reach G
        D_Loss = torch.mean(torch.log(disc_fake) + torch.log(1. - disc_real))
        # can decide later how much each loss term weighs
        opt_disc.zero_grad()
        D_Loss.backward()
        opt_disc.step()
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {D_Loss:.4f}, loss G: {loss_G_Final:.4f}"
            )
            with torch.no_grad():
                # perturb the clean fixed batch with the current generator
                # (without mutating fixed_input, so perturbations don't accumulate across epochs)
                fake = fixed_input + G(fixed_input) * 0.4
            img_grid_fake = torchvision.utils.make_grid(
                fake.reshape(-1, 1, 28, 28), normalize=True
            )
            img_grid_real = torchvision.utils.make_grid(
                real.reshape(-1, 1, 28, 28), normalize=True
            )
            writer_fake.add_image(
                "Mnist Fake Images", img_grid_fake, global_step=step
            )
            writer_real.add_image(
                "Mnist Real Images", img_grid_real, global_step=step
            )
            step += 1
torch.save(D.state_dict(), "./disc_dict")
torch.save(G.state_dict(), "./gen_dict")
When the losses are printed at the start of each epoch, both come out as NaN.
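I suspect the raw log terms may be involved: torch.log is unbounded at 0, so if D's sigmoid saturates, a -inf can enter the loss, and (as far as I understand) the backward pass then produces inf/nan gradients, after which every printed loss is nan. A quick standalone check of the boundary values, separate from the script above:

import torch

p = torch.tensor([0.0, 0.5, 1.0])  # values a saturated sigmoid can emit
print(torch.log(p))                # tensor([  -inf, -0.6931,  0.0000])
print(torch.log(1. - p))           # tensor([ 0.0000, -0.6931,   -inf])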
How can I get the losses to change the way I described above?