training loop code is here, can anyone help ?
error happen here: d_loss.backward()
for epoch in range(num_epoch):
for i, (img,_) in enumerate(dataloader):
if (i+1)*batch_size < num_samples:
batch_vectors=torch.cat((label_vectors[i*batch_size:(i+1)*batch_size]),0)
batch_vectors=batch_vectors.view(-1,32).float()
else:
batch_vectors=torch.cat((label_vectors[i*batch_size:-1]),0)
batch_vectors=batch_vectors.view(-1,32).float()
num_img = img.size(0)
#train discriminator
# compute loss of real_matched_img
img = img.view(num_img,3,96,96)
real_img = Variable(img).to(device)
real_label = Variable(torch.ones(num_img)).to(device)
fake_label = Variable(torch.zeros(num_img)).to(device)
batch_vectors = Variable(batch_vectors).to(device)
matched_real_out = D(real_img,batch_vectors)
d_loss_matched_real = criterion(matched_real_out, real_label)
matched_real_scores = matched_real_out # closer to 1 means better
# compute loss of fake_matched_img
z = Variable(torch.randn(num_img, z_dimension)).to(device)
z = torch.cat((z,batch_vectors),axis=1).to(device)
fake_img = G(z)
matched_fake_out = D(fake_img,batch_vectors)
d_loss_matched_fake = criterion(matched_fake_out, fake_label)
matched_fake_out_scores = matched_fake_out # closer to 0 means better
# compute loss of real_unmatched_img
rand_label_vectors=random.sample(label_vectors,num_img)
rand_batch_vectors=torch.cat((rand_label_vectors[:]),0)
rand_batch_vectors=rand_batch_vectors.view(-1,32).float().to(device)
z = Variable(torch.randn(num_img, z_dimension)).to(device)
z = torch.cat((z,rand_batch_vectors),axis=1).to(device)
fake_img = G(z)
unmatched_real_out = D(fake_img,batch_vectors)
d_loss_unmatched_real = criterion(unmatched_real_out, fake_label)
unmatched_real_out_scores = unmatched_real_out # closer to 0 means better
# bp and optimize
d_loss = d_loss_matched_real + d_loss_matched_fake + d_loss_unmatched_real
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# ===============train generator
# compute loss of fake_img
# compute loss of fake_matched_img
z = Variable(torch.randn(num_img, z_dimension)).to(device)
z = torch.cat((z,batch_vectors),axis=1).to(device)
fake_img = G(z)
matched_fake_out = D(fake_img,batch_vectors)
matched_fake_out_scores = matched_fake_out
g_loss = criterion(matched_fake_out,real_label)
# bp and optimize
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
print('Epoch [{}/{}], Batch {},d_loss: {:.6f}, g_loss: {:.6f} '
.format(
epoch, num_epoch,i,d_loss.data, g_loss.data,
))
if epoch == 0:
real_images = to_img(real_img.cpu().data)
save_image(real_images, './img/real_images.png')
fake_images = to_img(fake_img.cpu().data)
save_image(fake_images, './img/fake_images-{}.png'.format(i+(epoch-1)*batch_size))