Sudden drop in loss when lr scheduler is applied

I am training a conditional variational autoencoder (CVAE) with an ELBO loss, and I see a sudden drop in the loss at the epoch where the learning-rate scheduler decays the learning rate.
Does anyone have an idea why decaying the learning rate makes my loss drop suddenly, instead of continuing smoothly from the loss I was getting with the learning rate before the decay?

Here is the code snippet:

def loss_fn(mu_z, std_z, z_sample, mu_x, std_x, x):
    """Compute a single-sample Monte Carlo estimate of the negative ELBO.

    ``std_z`` and ``std_x`` are *log* standard deviations: both are passed
    through ``torch.exp`` before the corresponding Normal distribution is built.

    Args:
        mu_z: Mean of the approximate posterior q(z|x).
        std_z: Log standard deviation of q(z|x).
        z_sample: Latent sample at which q(z|x) and the prior p(z) are evaluated.
        mu_x: Mean of the likelihood p(x|z).
        std_x: Log standard deviation of p(x|z).
        x: Observed data; ``x.shape[0]`` is used as the batch size S.

    Returns:
        A tuple ``(sum_log_q_z, sum_log_p_x, sum_log_p_z, loss)`` of scalar
        tensors, where ``loss = (sum_log_q_z - sum_log_p_x - sum_log_p_z) / S``.
    """
    batch_size = x.shape[0]
    # Keep the target on the decoder output's device; otherwise log_prob fails
    # when the model runs on GPU but x comes directly from a CPU DataLoader.
    x = x.to(mu_x.device)

    # log posterior q(z|x)
    q_z_dist = torch.distributions.Normal(mu_z, torch.exp(std_z))
    sum_log_q_z = torch.sum(q_z_dist.log_prob(z_sample))
    # log likelihood p(x|z)
    p_x_dist = torch.distributions.Normal(mu_x, torch.exp(std_x))
    sum_log_p_x = torch.sum(p_x_dist.log_prob(x))
    # log prior p(z) = N(0, 1), broadcast against z_sample
    p_z_dist = torch.distributions.Normal(0.0, 1.0)
    sum_log_p_z = torch.sum(p_z_dist.log_prob(z_sample))

    # Negative ELBO, averaged over the batch.
    loss = (sum_log_q_z - sum_log_p_x - sum_log_p_z) / batch_size
    return sum_log_q_z, sum_log_p_x, sum_log_p_z, loss

optimizer = torch.optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=lr)
# StepLR multiplies the learning rate by `gamma` every `step_size` calls to
# scheduler.step(); it is stepped once per epoch below.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

train_dataset = TensorDataset(X_train, C_train, K_train)
test_dataset = TensorDataset(X_test, C_test, K_test)

train_iter = DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_iter = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Per-epoch mean losses; a fresh accumulator is appended at each epoch start.
train_loss_avg = []
test_loss_avg = []

for i in range(N_EPOCHS):
    # ---- training ----
    train_loss_avg.append(0.0)
    num_batches = 0
    for x, c, k in train_iter:
        x, c, k = x.to(device), c.to(device), k.to(device)
        optimizer.zero_grad()
        # forward pass
        # NOTE(review): encoder/decoder inputs were lost from the original
        # snippet; reconstructed as the usual CVAE concatenations — confirm.
        mu_z, std_z = enc(torch.cat([x, c, k], dim=1))
        eps = torch.randn_like(std_z)
        # reparameterisation trick; std_z is a log-std, hence the exp
        z_samples = mu_z + eps * torch.exp(std_z)
        mu_x, std_x = dec(torch.cat([z_samples, c, k], dim=1))
        # loss
        _, _, _, loss = loss_fn(mu_z, std_z, z_samples, mu_x, std_x, x)
        # backward pass
        loss.backward()
        # update
        optimizer.step()
        train_loss_avg[-1] += loss.item()
        num_batches += 1
    # NOTE(review): the body of this `if` was lost from the snippet; presumably
    # it steps the scheduler only for the first 501 epochs, so the single 10x
    # decay happens after epoch 500 — confirm.
    if i < 501:
        scheduler.step()

    # max(..., 1) avoids ZeroDivisionError on an empty loader
    train_loss_avg[-1] /= max(num_batches, 1)

    # ---- evaluation ----
    test_loss_avg.append(0.0)
    with torch.no_grad():
        num_batches = 0
        for x_test, c_test, k_test in test_iter:
            x_test, c_test, k_test = x_test.to(device), c_test.to(device), k_test.to(device)
            # forward
            mu_z_test, std_z_test = enc(torch.cat([x_test, c_test, k_test], dim=1))
            eps_test = torch.randn_like(std_z_test)
            z_samples_test = mu_z_test + eps_test * torch.exp(std_z_test)
            mu_x_test, std_x_test = dec(torch.cat([z_samples_test, c_test, k_test], dim=1))
            # loss
            _, _, _, test_loss = loss_fn(mu_z_test, std_z_test, z_samples_test, mu_x_test, std_x_test, x_test)
            test_loss_avg[-1] += test_loss.item()
            num_batches += 1
        test_loss_avg[-1] /= max(num_batches, 1)
    print("Epoch [%d / %d] train loss: %f, test loss: %f" % (i+1, N_EPOCHS, train_loss_avg[-1], test_loss_avg[-1]))

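This effect is often observed when the learning rate is decreased. With the larger learning rate, the loss can get “stuck” because the gradient steps are too large. The optimizer keeps overshooting and bouncing around a minimum instead of descending into it. Once the step size shrinks, it can settle further into the minimum, and the loss drops abruptly. For example, the ResNet paper shows the same behavior in its loss curves, although the paper doesn't discuss this effect.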