I am training a conditional variational autoencoder (CVAE) with an ELBO loss, and I see a sudden drop in the loss at the epoch where the learning-rate scheduler decays the learning rate.
Does anyone have an idea why the learning-rate decay makes my loss drop abruptly, instead of the loss continuing smoothly from the value it had with the learning rate before the decay?
Here is the code snippet:
def loss_fn(mu_z, std_z, z_sample, mu_x, std_x, x):
    """Single-sample Monte-Carlo estimate of the negative ELBO, averaged over the batch.

    NOTE: despite their names, ``std_z`` and ``std_x`` are treated as *log*
    standard deviations. They are exponentiated before being used as the
    scale of the Normal distributions.

    Args:
        mu_z: mean of the approximate posterior q(z|x).
        std_z: log standard deviation of q(z|x).
        z_sample: latent sample drawn from q(z|x).
        mu_x: mean of the likelihood p(x|z).
        std_x: log standard deviation of p(x|z).
        x: observed data. The batch size is taken from ``x.shape[0]``.
            It is moved to ``mu_x``'s device before use.

    Returns:
        Tuple ``(sum log q(z|x), sum log p(x|z), sum log p(z), loss)``.
        The sums run over every element, and
        ``loss = (sum log q - sum log p(x|z) - sum log p(z)) / batch_size``.
    """
    batch_size = x.shape[0]
    # Callers may hand over data that still lives on the CPU while the decoder
    # output is on the accelerator; align them so log_prob does not fail.
    x = x.to(device=mu_x.device)

    # log posterior q(z|x)
    q_z_dist = torch.distributions.Normal(mu_z, torch.exp(std_z))
    log_q_z = q_z_dist.log_prob(z_sample).sum()

    # log likelihood p(x|z)
    p_x_dist = torch.distributions.Normal(mu_x, torch.exp(std_x))
    log_p_x = p_x_dist.log_prob(x).sum()

    # Standard-normal prior p(z), created on z_sample's device and dtype
    # instead of relying on implicit CPU-scalar / accelerator-tensor mixing.
    p_z_dist = torch.distributions.Normal(
        torch.zeros_like(z_sample), torch.ones_like(z_sample)
    )
    log_p_z = p_z_dist.log_prob(z_sample).sum()

    # Each sum is computed once and reused for both the loss and the return.
    loss = (log_q_z - log_p_x - log_p_z) / batch_size
    return log_q_z, log_p_x, log_p_z, loss
# One Adam optimizer updates encoder and decoder parameters jointly.
optimizer = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=lr)
# StepLR multiplies the learning rate by `gamma` once every `step_size` calls
# to scheduler.step(). It counts calls, not epochs, so where step() is called
# determines when the decay happens.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

train_dataset = TensorDataset(X_train, C_train, K_train)
test_dataset = TensorDataset(X_test, C_test, K_test)
# Shuffle the training set so each epoch sees a different batch order.
# Without it every epoch replays identical mini-batches, which biases
# stochastic gradient descent. Evaluation order does not matter.
train_iter = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_iter = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Per-epoch mean losses, one entry appended per epoch by the training loop.
train_loss_avg = []
test_loss_avg = []
def _negative_elbo(x, c, k):
    """Run encoder -> reparameterised sample -> decoder on one batch; return loss_fn's loss.

    All inputs are moved to ``device`` (including ``x``, which loss_fn also
    needs on the same device as the decoder output). The conditioning input
    is ``c`` and ``k`` concatenated along dim 1.
    """
    x = x.to(device)
    cond = torch.cat([c, k], dim=1).to(device)
    mu_z, std_z = enc(x, cond)
    # Reparameterisation trick: z = mu + eps * exp(std_z), so gradients flow
    # through mu_z and std_z (std_z is used as a log standard deviation).
    eps = torch.randn_like(std_z)
    z_samples = mu_z + eps * torch.exp(std_z)
    mu_x, std_x = dec(z_samples, cond)
    _, _, _, loss = loss_fn(mu_z, std_z, z_samples, mu_x, std_x, x)
    return loss


for i in range(N_EPOCHS):
    # ---- training pass ----
    enc.train()
    dec.train()
    epoch_loss = 0.0
    num_batches = 0
    for x, c, k in train_iter:
        optimizer.zero_grad()
        loss = _negative_elbo(x, c, k)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    train_loss_avg.append(epoch_loss / num_batches)

    # Step the scheduler once per epoch. With StepLR(step_size=500) the LR is
    # multiplied by gamma on the 500th call; stopping after epoch index 500
    # prevents the next decay at call 1000, so the LR is decayed exactly once.
    #
    # The sudden loss drop right after the decay is expected behaviour, not a
    # bug. With the larger LR, each Adam step is too big to settle, so the
    # parameters keep bouncing around the bottom of a valley of the loss
    # surface and the loss plateaus above that valley's minimum. Shrinking the
    # step size by 10x lets the optimiser descend into the valley within a few
    # updates, which shows up as a step-shaped drop in the curve.
    if i < 501:
        scheduler.step()

    # ---- evaluation pass ----
    # eval() switches layers such as dropout / batch-norm (if present) to
    # inference behaviour; no_grad() avoids building the autograd graph.
    enc.eval()
    dec.eval()
    with torch.no_grad():
        epoch_loss = 0.0
        num_batches = 0
        for x_test, c_test, k_test in test_iter:
            epoch_loss += _negative_elbo(x_test, c_test, k_test).item()
            num_batches += 1
        test_loss_avg.append(epoch_loss / num_batches)

    print("Epoch [%d / %d] train loss: %f, test loss: %f" % (i+1, N_EPOCHS, train_loss_avg[-1], test_loss_avg[-1]))