I am using VAE and GAN and training them combined.
I have 4 paramets theta_s, theta_e, theta_d, theta_c
theta_s is for some spatial loss
theta_e is for encoder
theta_d is for decoder
theta_c is for discriminator
My objectives are
- For learning {theta_s, theta_e} => minimize (Lreconst +Lprior +Lsparsity )
- For learning {theta_d} => minimize(Lreconst + Lgan)
- For learning {theta_c} => maximize(Lgan)
here is the code below for training
def train(inputs=[], sLSTM=None, eLSTM=None, dLSTM=None, cLSTM=None, epochs=100, lr=0.01):
"""
Train all models at once
"""
losses = np.zeros((epochs, 4))
S_solver = optim.SGD(sLSTM.parameters(), lr=lr)
E_solver = optim.SGD(eLSTM.parameters(), lr=lr)
D_solver = optim.SGD(dLSTM.parameters(), lr=lr)
C_solver = optim.SGD(cLSTM.parameters(), lr=lr)
for epoch in xrange(1, epochs+1):
# batch_score = 0
for X in inputs:
# Selector LSTM score computation
sLSTM.hidden = sLSTM.init_hidden()
frames_scores = sLSTM(X) # scores for every frame in a mini batch
# print frames_scores
S_loss = spatial_loss(frames_scores.sum())
# batch_score += frames_scores.sum() # summation of scores for whole batch
# Multiply frames scores with input
X = frames_scores.view(-1, 1) * X
X.detach_() # detach its history from selector lstm
X.volatile = False
# Moving into eLSTM
""" Reconstruction Phase """
z_sample = eLSTM(X)
X_sample = dLSTM(z_sample)
recon_loss = F.binary_cross_entropy(X_sample, X)
""" Regularization phase """
# Disciminator or cLSTM
z_real = Variable(torch.randn(BATCH_DIM, Z_DIM))
z_fake = eLSTM(X)
C_real = cLSTM(z_real)
C_fake = cLSTM(z_fake)
C_loss = -torch.mean(torch.log(C_real) + torch.log(1 - C_fake)) # Prior loss
mix_loss = recon_loss + S_loss + C_loss
mix_loss.backward()
S_solver.step()
E_solver.step()
# Reset gradients
sLSTM.zero_grad()
eLSTM.zero_grad()
dLSTM.zero_grad()
cLSTM.zero_grad()
# Generator
z_fake = eLSTM(X)
C_fake = cLSTM(z_fake)
g_loss = torch.mean(torch.log(C_fake))
G_loss = -g_loss
mix_loss2 = recon_loss + g_loss
mix_loss2.backward()
D_solver.step()
# Reset gradients
eLSTM.zero_grad()
dLSTM.zero_grad()
cLSTM.zero_grad()
G_loss.backward()
C_solver.step()
# Reset gradients
eLSTM.zero_grad()
dLSTM.zero_grad()
cLSTM.zero_grad()
# Spatial loss
# S_loss = spatial_loss(batch_score)
# S_solver.zero_grad()
# S_loss.backward()
# S_solver.step()
losses[epoch-1] = S_loss.data[0], recon_loss.data[0], G_loss.data[0], C_loss.data[0]
if epoch % 10 == 0:
print("Epoch: %d S_loss = %.4f; R_loss = %.4f; G_loss = %.4f; C_loss = %.4f" % (
epoch, S_loss.data[0], recon_loss.data[0], G_loss.data[0], C_loss.data[0]
))
return losses
Now , the code was working fine for 1 and 2 objective , but when i introduced mixloss2 i.e for object 2 , its giving me error in line mix_loss2.backward()
and saying put retrain_graph=True
so can anyone help me out , what am i doing wrong here ?