Hi all. I’ve generally seen it recommended against using the retain_graph parameter, but I can’t seem to get a piece of my code working without it. In particular, this is the code where I use retain_graph (the goal here is to optimize a single scalar cql_alpha
value:
# calculate eq. 5 in updated SAC paper
qf1_values = qf1(observations)
qf2_values = qf2(observations)
qf1_a_values = qf1_values.gather(1, actions)
qf2_a_values = qf2_values.gather(1, actions)
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = 0.5 * (qf1_loss + qf2_loss) # scaling added from CQL
# calculate CQL regularization loss
cql_qf1_diff = torch.logsumexp(qf1_values, dim=1).mean() - qf1_a_values.mean()
cql_qf2_diff = torch.logsumexp(qf2_values, dim=1).mean() - qf2_a_values.mean()
if args.cql_autotune:
cql_alpha = torch.clamp(torch.exp(cql_log_alpha), min=0.0, max=1000000.0)
cql_qf1_loss = cql_alpha * (cql_qf1_diff - args.difference_threshold)
cql_qf2_loss = cql_alpha * (cql_qf2_diff - args.difference_threshold)
cql_alpha_loss = -(cql_qf1_loss + cql_qf2_loss)
# ---------- update cql_alpha ---------- #
cql_a_optimizer.zero_grad()
cql_alpha_loss.backward(retain_graph=True)
cql_a_optimizer.step()
else:
cql_qf1_loss = cql_alpha * cql_qf1_diff
cql_qf2_loss = cql_alpha * cql_qf2_diff
# calculate final q-function loss which is a combination of Bellman error and CQL regularization
cql_qf_loss = cql_qf1_loss + cql_qf2_loss
qf_loss = cql_qf_loss + qf_loss
# calculate eq. 6 in updated SAC paper
q_optimizer.zero_grad()
qf_loss.backward()
q_optimizer.step()
I tried the following approach to remove the usage, but while it seems to work initially, eventually the Q-values explode to infinity with this approach.
# calculate eq. 5 in updated SAC paper
qf1_values = qf1(observations)
qf2_values = qf2(observations)
qf1_a_values = qf1_values.gather(1, actions)
qf2_a_values = qf2_values.gather(1, actions)
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = 0.5 * (qf1_loss + qf2_loss) # scaling added from CQL
# calculate CQL regularization loss
if args.cql_autotune:
cql_alpha = torch.clamp(
cql_log_alpha.exp().detach(), min=0.0, max=1000000.0
)
cql_qf1_loss = (
torch.logsumexp(qf1_values, dim=1).mean() - qf1_a_values.mean()
)
cql_qf2_loss = (
torch.logsumexp(qf2_values, dim=1).mean() - qf2_a_values.mean()
)
cql_qf1_loss = cql_alpha * (cql_qf1_loss - args.difference_threshold)
cql_qf2_loss = cql_alpha * (cql_qf2_loss - args.difference_threshold)
else:
cql_qf1_loss = cql_alpha * (
torch.logsumexp(qf1_values, dim=1).mean() - qf1_a_values.mean()
)
cql_qf2_loss = cql_alpha * (
torch.logsumexp(qf2_values, dim=1).mean() - qf2_a_values.mean()
)
# calculate final q-function loss which is a combination of Bellman error and CQL regularization
cql_qf_loss = cql_qf1_loss + cql_qf2_loss
qf_loss = cql_qf_loss + qf_loss
# calculate eq. 6 in updated SAC paper
q_optimizer.zero_grad()
qf_loss.backward()
q_optimizer.step()
# ---------- update cql_alpha ---------- #
if args.cql_autotune:
with torch.no_grad():
cql_qf1_diff_loss = cql_qf1_loss - args.difference_threshold
cql_qf2_diff_loss = cql_qf2_loss - args.difference_threshold
cql_alpha = torch.clamp(cql_log_alpha.exp(), min=0.0, max=1000000.0)
cql_alpha_loss = cql_alpha * -(cql_qf1_diff_loss + cql_qf2_diff_loss)
cql_a_optimizer.zero_grad()
cql_alpha_loss.backward()
cql_a_optimizer.step()
Would anyone be able to help me figure out what I might be doing wrong?