Using multiple optimizers inside a loop with PyTorch Lightning

I am working on a big project for which I need to call manual_backward and optimizer.step inside a loop for every batch.
Here is some reference code for a loss function that works (loss_fn_working), and one that doesn't (loss_fn_not_working):

def loss_fn_working(self, batch: Any, batch_idx: int):
    env = self.envs[self.p]
    actions = None
    prev_log_rewards = torch.empty(env.done.shape[0]).type_as(env.state)
    prev_forward_logprob = None
    loss = torch.tensor(0.0, requires_grad=True)
    TERM = env.terminal_index

    while not torch.all(env.done):
        active = ~env.done
        forward_logprob, back_logprob = self.forward(env)

        log_rewards = -self.get_rewards()

        if actions is not None:
            error = log_rewards - prev_log_rewards[active]
            error += back_logprob.gather(1, actions[actions != TERM, None]).squeeze(1)
            error += prev_forward_logprob[active, -1]
            error -= forward_logprob[:, -1].detach()
            error -= (
                prev_forward_logprob[active]
                .gather(1, actions[actions != TERM, None])
                .squeeze(1)
            )
            loss = loss + F.huber_loss(
                error,
                torch.zeros_like(error),
                delta=1.0,
                reduction="none",
            )
            loss = loss * log_rewards.softmax(0)
            loss = loss.mean(0)

        actions = self.sample_actions(forward_logprob, active, TERM)
        env.step(actions)

        # save previous log-probs and log-rewards
        if prev_forward_logprob is None:
            prev_forward_logprob = torch.empty_like(forward_logprob)
        prev_forward_logprob[active] = forward_logprob
        prev_log_rewards[active] = log_rewards

    return loss, log_rewards


def calculate_loss(
    self,
    loss,
    log_rewards,
    prev_log_rewards,
    back_log_prob,
    actions,
    stop_prob,
    prefix,  # Added for debugging
    prev_stop_prob=None,
    prev_forward_log_prob=None,
):
    error = torch.tensor(0.0, requires_grad=True) + log_rewards - prev_log_rewards  # [B]
    error = error + back_log_prob.gather(1, actions.unsqueeze(1)).squeeze(1)  # P_B(s|s')
    error = error - stop_prob.detach()  # P(s_f|s')
    if prev_stop_prob is not None and prev_forward_log_prob is not None:
        error = error + prev_stop_prob.detach()  # P(s_f|s)
        error = error - prev_forward_log_prob.gather(
            1, actions.unsqueeze(1)
        ).squeeze(1)

    loss = loss + F.huber_loss(  # accumulate losses
        error,
        torch.zeros_like(error),
        delta=1.0,
        reduction="none",
    )

    loss = loss * log_rewards.softmax(0)
    return loss.mean(0)

def loss_fn_not_working(self, batch, batch_size, prefix, batch_idx):
    gfn_opt, rep_opt = self.optimizers()
    # some code here (counter, next_id, indices, etc. are initialized in the omitted part)
    losses = []
    rep_losses = []
    prev_forward_log_prob = None
    prev_stop_prob = torch.zeros(batch_size, device='cuda')
    loss = torch.tensor(0.0, requires_grad=True, device='cuda')
    active = torch.ones((batch_size,), dtype=bool, device='cuda')
    graph = torch.diag_embed(torch.ones(batch_size, self.n_dim)).cuda()
    while active.any():
        graph_hat = graph[active].clone()
        adj_mat = graph_hat.clone()
        rep_loss, latent_var = self.rep_model(torch.cat((adj_mat, next_id.unsqueeze(-1)), axis=-1))
        rep_loss_tensor = torch.tensor(0.0, requires_grad=True) + rep_loss

        forward_log_prob, Fs_masked, back_log_prob, next_prob, stop_prob = (
            self.gfn_model(latent_var)
        )
        with torch.no_grad():
            actions = self.sample_actions(Fs_masked)
            graph = self.update_graph(actions)
            #######################

            log_rewards = -self.energy_model(graph_hat, batch, False, self.current_epoch)

        if counter == 0:
            loss = self.calculate_loss(
                loss, log_rewards, prev_log_rewards, back_log_prob,
                actions, stop_prob, prefix,
            )
        else:
            loss = self.calculate_loss(
                loss, log_rewards, prev_log_rewards, back_log_prob,
                actions, stop_prob, prefix,
                prev_stop_prob[active], prev_forward_log_prob[active],
            )
        losses.append(loss.item())
        rep_losses.append(rep_loss.item())
        if prefix == 'train':
            rep_opt.zero_grad()
            self.manual_backward(rep_loss_tensor, retain_graph=True)
            rep_opt.step()
            gfn_opt.zero_grad()
            self.manual_backward(loss)
            self.clip_gradients(gfn_opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm") # NEEDED??
            gfn_opt.step()

        with torch.no_grad():
            active[indices_to_deactivate] = False #active updated appropriately
            indices = indices[~current_stop]
            # active_indices = ~current_stop # Not being used?
            next_id = F.one_hot(indices, num_classes=self.n_dim)
            prev_log_rewards = log_rewards[~current_stop]
            counter += 1
        if prev_forward_log_prob is None:
            prev_forward_log_prob = torch.empty_like(forward_log_prob)
        prev_forward_log_prob[active] = forward_log_prob[~current_stop]
        prev_stop_prob[active] = stop_prob[~current_stop]

    return losses, graph, log_rewards, counter, rep_losses

Here, the main variable of importance is prev_forward_log_prob in loss_fn_not_working. The loss is computed with the calculate_loss() function.
I have manual optimization enabled (i.e. self.automatic_optimization = False).
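For context, the Lightning setup looks roughly like this (a simplified sketch; the class name, the placeholder sub-models, and the Adam optimizers with lr=1e-3 are illustrative, not my actual code):

import torch
from torch import nn
import pytorch_lightning as pl


class GFNLitModule(pl.LightningModule):  # illustrative name
    def __init__(self):
        super().__init__()
        # placeholder sub-models standing in for my real gfn_model / rep_model
        self.gfn_model = nn.Linear(10, 10)
        self.rep_model = nn.Linear(10, 10)
        self.automatic_optimization = False  # manual optimization

    def configure_optimizers(self):
        # two optimizers, one per sub-model, returned in the order they are
        # unpacked from self.optimizers() in loss_fn_not_working
        gfn_opt = torch.optim.Adam(self.gfn_model.parameters(), lr=1e-3)
        rep_opt = torch.optim.Adam(self.rep_model.parameters(), lr=1e-3)
        return gfn_opt, rep_opt

    def training_step(self, batch, batch_idx):
        # all manual_backward / opt.step calls happen inside the while loop
        # of loss_fn_not_working shown above, once per loop iteration
        ...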

When using loss_fn_not_working with retain_graph=False for loss, I get the following error:

Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

If I instead set retain_graph=True for loss (i.e. the loss for the second optimizer), I get the following error:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 10]], which is output 0 of AsStridedBackward0, is at version 3; expected version 1 instead. 
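(For reference, my understanding is that the version numbers in this message come from autograd's per-tensor version counter, which is bumped by every in-place modification of a tensor that was saved for backward. A tiny standalone illustration using the internal ._version attribute:)

import torch

w = torch.randn(3, requires_grad=True)
buf = torch.randn(3)

out = (w * buf).sum()  # buf is saved for backward, since d(out)/dw needs it
print(buf._version)    # 0
buf[0] = 1.0           # indexed assignment is an in-place op: version counter goes to 1
print(buf._version)    # 1
out.backward()         # RuntimeError: ... modified by an inplace operation ... at version 1; expected version 0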

If I use loss_fn_working, there is no problem. So, I understand that the problem arises from calling backward inside the loop. I am not really making any in-place operations, so how can I make the second loss_fn work? I tried cloning the relevant variables, but that doesn't help unless I also detach them, which I can't do.
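(By cloning vs. detaching I mean the standard PyTorch distinction, roughly:)

import torch

x = torch.randn(4, 3, requires_grad=True) * 2  # stand-in for a tensor inside my graph

cloned = x.clone()     # copies the data but stays connected to the autograd graph
detached = x.detach()  # shares storage but is cut off from the graph

print(cloned.requires_grad)    # True  -> gradients still flow through the clone
print(detached.requires_grad)  # False -> no gradients flow back, which I can't afford here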

@ptrblck I saw your previous answers on similar errors, but none of those solutions seem to work here. Do you know why this could be happening, and how I can change the code to make it work?

I think I can break the code down into a minimal, cleaner example:

while active.any():
    forward_log_prob = self.model(inputs)
    actions = self.sample_actions()  # actions does not require grad
    error = torch.tensor(0.0, requires_grad=True)
    error = error - prev_forward_log_prob.gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = F.huber_loss(
        error,
        torch.zeros_like(error),
        delta=1.0,
        reduction="none",
    )
    if prefix == 'train':
        self.manual_backward(loss, retain_graph=True)
    prev_forward_log_prob[active] = forward_log_prob  # indexed assignment updates the buffer for the next iteration

Here, error, loss, forward_log_prob, and prev_forward_log_prob have requires_grad=True; active, actions, and the other variables do not.
When counter == 2, self.manual_backward raises the following error:

one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABoolType [512]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
  File "/var/local/aurora/GFNSSD/src/models/parallel_energy_gfn_module.py", line 577, in detailed_balance
    self.manual_backward(loss, retain_graph=True)
  File "/var/local/aurora/GFNSSD/src/models/parallel_energy_gfn_module.py", line 616, in training_step
    final_gfn_loss, graphs, avg_log_rewards, traj_length, final_rep_loss = self.detailed_balance(batch, self.B, prefix='train', batch_idx=batch_idx)
                                                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "path/src/training_pipeline.py", line 147, in train
    trainer.fit(model=model, datamodule=datamodule)
  File "path/train.py", line 22, in main
    return train(config)
           ^^^^^^^^^^^^^
  File "path/train.py", line 26, in <module>
    main()
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABoolType [512]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Please help me out here!