How to replace usage of "retain_graph=True"

Hi all. I've generally seen it recommended to avoid the retain_graph parameter, but I can't seem to get a piece of my code working without it. In particular, this is the code where I use retain_graph (the goal here is to optimize a single scalar cql_alpha value):

        # calculate eq. 5 in updated SAC paper
        qf1_values = qf1(observations)
        qf2_values = qf2(observations)
        qf1_a_values = qf1_values.gather(1, actions)
        qf2_a_values = qf2_values.gather(1, actions)
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        qf_loss = 0.5 * (qf1_loss + qf2_loss)  # scaling added from CQL

        # calculate CQL regularization loss
        cql_qf1_diff = torch.logsumexp(qf1_values, dim=1).mean() - qf1_a_values.mean()
        cql_qf2_diff = torch.logsumexp(qf2_values, dim=1).mean() - qf2_a_values.mean()
        if args.cql_autotune:
            cql_alpha = torch.clamp(torch.exp(cql_log_alpha), min=0.0, max=1000000.0)
            cql_qf1_loss = cql_alpha * (cql_qf1_diff - args.difference_threshold)
            cql_qf2_loss = cql_alpha * (cql_qf2_diff - args.difference_threshold)
            cql_alpha_loss = -(cql_qf1_loss + cql_qf2_loss)

            # ---------- update cql_alpha ---------- #
            cql_a_optimizer.zero_grad()
            cql_alpha_loss.backward(retain_graph=True)
            cql_a_optimizer.step()
        else:
            cql_qf1_loss = cql_alpha * cql_qf1_diff
            cql_qf2_loss = cql_alpha * cql_qf2_diff

        # calculate final q-function loss which is a combination of Bellman error and CQL regularization
        cql_qf_loss = cql_qf1_loss + cql_qf2_loss
        qf_loss = cql_qf_loss + qf_loss

        # calculate eq. 6 in updated SAC paper
        q_optimizer.zero_grad()
        qf_loss.backward()
        q_optimizer.step()

I tried the following approach to remove it, but while it seems to work at first, the Q-values eventually explode to infinity.

        # calculate eq. 5 in updated SAC paper
        qf1_values = qf1(observations)
        qf2_values = qf2(observations)
        qf1_a_values = qf1_values.gather(1, actions)
        qf2_a_values = qf2_values.gather(1, actions)
        qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
        qf_loss = 0.5 * (qf1_loss + qf2_loss)  # scaling added from CQL

        # calculate CQL regularization loss
        if args.cql_autotune:
            cql_alpha = torch.clamp(
                cql_log_alpha.exp().detach(), min=0.0, max=1000000.0
            )
            cql_qf1_loss = (
                torch.logsumexp(qf1_values, dim=1).mean() - qf1_a_values.mean()
            )
            cql_qf2_loss = (
                torch.logsumexp(qf2_values, dim=1).mean() - qf2_a_values.mean()
            )
            cql_qf1_loss = cql_alpha * (cql_qf1_loss - args.difference_threshold)
            cql_qf2_loss = cql_alpha * (cql_qf2_loss - args.difference_threshold)
        else:
            cql_qf1_loss = cql_alpha * (
                torch.logsumexp(qf1_values, dim=1).mean() - qf1_a_values.mean()
            )
            cql_qf2_loss = cql_alpha * (
                torch.logsumexp(qf2_values, dim=1).mean() - qf2_a_values.mean()
            )

        # calculate final q-function loss which is a combination of Bellman error and CQL regularization
        cql_qf_loss = cql_qf1_loss + cql_qf2_loss
        qf_loss = cql_qf_loss + qf_loss

        # calculate eq. 6 in updated SAC paper
        q_optimizer.zero_grad()
        qf_loss.backward()
        q_optimizer.step()

        # ---------- update cql_alpha ---------- #
        if args.cql_autotune:
            with torch.no_grad():
                cql_qf1_diff_loss = cql_qf1_loss - args.difference_threshold
                cql_qf2_diff_loss = cql_qf2_loss - args.difference_threshold

            cql_alpha = torch.clamp(cql_log_alpha.exp(), min=0.0, max=1000000.0)
            cql_alpha_loss = cql_alpha * -(cql_qf1_diff_loss + cql_qf2_diff_loss)

            cql_a_optimizer.zero_grad()
            cql_alpha_loss.backward()
            cql_a_optimizer.step()

Would anyone be able to help me figure out what I might be doing wrong?

I personally think it's OK to use retain_graph if you have multiple losses that share part of the graph. You should be careful to call zero_grad in the right places so that one backward pass does not impact another (you generally don't want gradients from several losses accumulating on the same parameters).
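A minimal sketch of that pattern (toy modules and losses, not your code; names like qnet and log_alpha are made up), where two optimizers own disjoint parameters but share one forward pass:

    import torch

    # toy setup: a small network plus one learnable log-coefficient,
    # each owned by its own optimizer, sharing a single forward pass
    qnet = torch.nn.Linear(4, 2)
    log_alpha = torch.zeros(1, requires_grad=True)
    q_opt = torch.optim.Adam(qnet.parameters(), lr=1e-3)
    alpha_opt = torch.optim.Adam([log_alpha], lr=1e-3)

    obs = torch.randn(16, 4)
    q_values = qnet(obs)                                # shared sub-graph
    alpha_loss = -(log_alpha.exp() * q_values.mean())   # touches log_alpha AND qnet
    q_loss = q_values.pow(2).mean()                      # touches qnet only

    alpha_opt.zero_grad()
    alpha_loss.backward(retain_graph=True)  # keep the shared graph for the second backward
    q_opt.zero_grad()                       # drop the qnet grads that alpha_loss just wrote
    q_loss.backward()                       # last backward through this graph, it can be freed
    alpha_opt.step()                        # log_alpha's grad was not touched by q_opt.zero_grad()
    q_opt.step()

Stepping both optimizers only after both backward calls also avoids modifying parameters in-place between the two backward passes.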

To answer your question, I don't see any obvious difference between the two. It feels like your first example could also run without retain_graph, since the alpha update only needs to differentiate through a very small graph.
The second solution definitely makes more sense.
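Concretely, one way to make the alpha graph that small in your first snippet (a sketch reusing your variable names, and assuming cql_log_alpha is the only parameter cql_a_optimizer updates) is to detach the diffs, so that cql_alpha_loss.backward() only has to go through cql_log_alpha:

    # the diffs only act as (constant) targets for the alpha update,
    # so detach them from the Q-network graph
    cql_qf1_diff_const = cql_qf1_diff.detach() - args.difference_threshold
    cql_qf2_diff_const = cql_qf2_diff.detach() - args.difference_threshold

    cql_alpha = torch.clamp(cql_log_alpha.exp(), min=0.0, max=1000000.0)
    cql_alpha_loss = -cql_alpha * (cql_qf1_diff_const + cql_qf2_diff_const)

    cql_a_optimizer.zero_grad()
    cql_alpha_loss.backward()  # the graph only contains cql_log_alpha, no retain_graph needed
    cql_a_optimizer.step()

For the Q-function loss you would then use a detached cql_alpha (as your second snippet already does), so that its backward never touches cql_log_alpha.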

Exploding Q-values can happen for multiple reasons with SAC-like algos: an incorrect implementation, a badly optimized entropy coefficient, …
Another common source is numerical instability when sampled actions land near the bounds of the action space and transforms are applied; you may want to look into that.
One thing I've found to help is to use a small weight decay in the optimizer(s).
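For example (just a sketch; the learning rate and decay value are placeholders):

    # a small L2 penalty on the Q-network weights via the optimizer's weight_decay argument
    q_optimizer = torch.optim.Adam(
        list(qf1.parameters()) + list(qf2.parameters()),
        lr=3e-4,
        weight_decay=1e-4,
    )

torch.optim.AdamW works too if you prefer the decay decoupled from the adaptive step.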
