Backward pass error for loss computation in loop

Hello,
I am currently training a model where I sample some data iteratively and compute the loss by accumulating a sum across iterations. A simplified version of the loss computation looks as follows:

        ctrls_idxs = self.categorical.sample((x_full.shape[0],))
        n_ctrls = self.interpolation_steps[ctrls_idxs]

        running_loss = th.tensor(0.0, device=model.device, dtype=x_full.dtype, requires_grad=False)
        total_elements = th.tensor(0, device=model.device, dtype=x_full.dtype, requires_grad=False)

        unique_cs = n_ctrls.unique()
        for c in unique_cs:
            idxs = th.where(n_ctrls == c)[0]
            x_subset = x_full[idxs]
            timesteps = th.full((x_subset.shape[0],), c, dtype=th.long, device=x_full.device)
            batch_subset = {k: batch[k][idxs].to(model.device) for k in batch.keys()}
            per_element_loss = F.mse_loss(model(x_subset, timesteps, batch=batch_subset), batch_subset["action"], reduction="none")
            running_loss = running_loss + per_element_loss.sum()
            total_elements = total_elements + per_element_loss.numel()

        total_loss = running_loss / total_elements
        return {"loss": total_loss}

This throws me the following error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [9, 2]] is at version 5; expected version 4 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!.

I understand that something goes wrong in the aggregation of the loss. However, a couple of observations are odd to me, which is why I struggle to pinpoint the cause:

  • If I disable DDP training, no error occurs
  • If I don't loop over all unique_cs, but instead do something like for c in unique_cs[1:2], no error occurs
  • If I use mixed-precision training, i.e. wrap my loss computation in with th.amp.autocast(device_type=model.device.type, dtype=th.bfloat16):, no error occurs
  • As soon as I switch to with th.amp.autocast(device_type=model.device.type, dtype=th.float32):, the error comes back

Can anyone help me with this issue? I would like to understand where exactly the backward pass fails. Unfortunately, the th.autograd.set_detect_anomaly(True) option gives me no additional hint.

A blanket fallback for resolving these types of “saved variable has been modified inplace” errors is the allow_mutation_on_saved_tensors context manager (see Automatic differentiation package - torch.autograd — PyTorch 2.5 documentation): it will insert a clone automatically for you where you need it to avoid the error.
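
For reference, a minimal usage sketch (my own wrapping; compute_loss here is a hypothetical stand-in for the method shown in the question):

    import torch as th

    # Sketch only: under allow_mutation_on_saved_tensors, autograd clones a tensor
    # saved for backward at the moment it is about to be mutated in place, so the
    # backward pass sees the original value instead of raising the version error.
    with th.autograd.graph.allow_mutation_on_saved_tensors():
        loss_dict = compute_loss(x_full, batch, model)  # hypothetical wrapper around the loop above
        loss_dict["loss"].backward()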

  • As soon as I switch to with th.amp.autocast(device_type=model.device.type, dtype=th.float32): the error comes back

How you use amp might change things, because autocast can interpose additional out-of-place ops (e.g. dtype casts), so the tensor saved for backward is a copy rather than the original, and the original being modified in place no longer matters.
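
As a self-contained toy illustration of that effect (two small Linear layers on CPU, not the original model): without autocast the second linear saves its fp32 input for backward, so an in-place edit of that input reproduces the error; under autocast the saved tensor is an out-of-place bf16 cast, so the same edit passes.

    import torch as th

    lin1, lin2 = th.nn.Linear(4, 4), th.nn.Linear(4, 4)
    x = th.randn(2, 4)

    # Without autocast: lin2 saves h itself (needed for its weight gradient),
    # so mutating h in place afterwards reproduces the version-counter error.
    h = lin1(x)
    out = lin2(h).sum()
    h.add_(1.0)
    try:
        out.backward()
    except RuntimeError as e:
        print("without autocast:", e)

    # With autocast: lin2 saves the bf16 cast of h (an out-of-place copy), so
    # the same in-place edit of the fp32 h no longer trips the version check.
    h = lin1(x)
    with th.amp.autocast(device_type="cpu", dtype=th.bfloat16):
        out = lin2(h).sum()
    h.add_(1.0)
    out.backward()
    print("with autocast: backward succeeded")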

Copying the source of allow_mutation_on_saved_tensors and adding some logging to see where it inserts the clone might be insightful, but maybe we can provide something more automatic…
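
In the meantime, here is a rough sketch of one way to get that kind of logging yourself (my own naming, not an existing PyTorch utility, and it leans on the internal Tensor._version counter): a saved_tensors_hooks subclass that records every tensor autograd saves, together with its version and the stack at save time, and reports which saved tensors were mutated in place before you call backward.

    import traceback
    import torch as th

    class SavedTensorMutationLogger(th.autograd.graph.saved_tensors_hooks):
        """Debugging sketch: remember every tensor saved for backward plus its
        version counter and the stack at save time, so mutated tensors can be
        reported before backward() raises."""

        def __init__(self):
            self.records = []

            def pack(tensor):
                self.records.append(
                    (tensor, tensor._version, "".join(traceback.format_stack(limit=6)))
                )
                return tensor  # keep the default behaviour: save the tensor itself

            def unpack(tensor):
                return tensor

            super().__init__(pack, unpack)

        def report_mutations(self):
            for tensor, saved_version, stack in self.records:
                if tensor._version != saved_version:
                    print(
                        f"saved tensor of shape {tuple(tensor.shape)} was mutated in place "
                        f"(version {saved_version} -> {tensor._version}); saved at:\n{stack}"
                    )

    # Usage: wrap the forward pass, then report right before calling backward().
    # logger = SavedTensorMutationLogger()
    # with logger:
    #     loss_dict = compute_loss(x_full, batch, model)  # hypothetical wrapper
    # logger.report_mutations()
    # loss_dict["loss"].backward()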