Hello,
I am currently training a model where I iteratively sample some data and compute the loss as a sum across iterations. A simplified version of the loss computation looks as follows (th is torch, F is torch.nn.functional):
ctrls_idxs = self.categorical.sample((x_full.shape[0],))
n_ctrls = self.interpolation_steps[ctrls_idxs]
running_loss = th.tensor(0.0, device=model.device, dtype=x_full.dtype, requires_grad=False)
total_elements = 0  # a plain Python int is enough for the element count
unique_cs = n_ctrls.unique()
for c in unique_cs:
    # select all samples that were assigned the same number of control steps
    idxs = th.where(n_ctrls == c)[0]
    x_subset = x_full[idxs]
    timesteps = th.full((x_subset.shape[0],), c, dtype=th.long, device=x_full.device)
    batch_subset = {k: batch[k][idxs].to(model.device) for k in batch.keys()}
    # keep the unreduced loss so both the sum and the true element count are available
    sub_loss = F.mse_loss(model(x_subset, timesteps, batch=batch_subset), batch_subset["action"], reduction="none")
    running_loss = running_loss + sub_loss.sum()
    total_elements = total_elements + sub_loss.numel()
total_loss = running_loss / total_elements
return {"loss": total_loss}
Running this throws the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [9, 2]] is at version 5; expected version 4 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I understand that something goes wrong in the aggregation of the loss. However, a couple of things seem odd to me, which is why I'm struggling to identify the cause:
- If I disable DDP training, no error occurs
- If I don't loop over all unique_cs but instead do something like `for c in unique_cs[1:2]`, no error occurs
- If I use mixed precision training, i.e. wrap my loss computation in `with th.amp.autocast(device_type=model.device.type, dtype=th.bfloat16):` (roughly as in the sketch after this list), no error occurs
- As soon as I switch to `with th.amp.autocast(device_type=model.device.type, dtype=th.float32):`, the error comes back
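For reference, the autocast variants above wrap the loss computation roughly like this; it is only a minimal sketch, and compute_loss is a placeholder name for the method shown at the top of the post:

# Minimal sketch of the mixed-precision variant; compute_loss is a placeholder
# for the actual loss method shown above.
with th.amp.autocast(device_type=model.device.type, dtype=th.bfloat16):
    loss_dict = compute_loss(x_full, batch)
loss_dict["loss"].backward()  # no error with bfloat16, error with float32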
Can anyone help me with this issue? I would like to understand where exactly the backward pass fails. Unfortunately, the th.autograd.set_detect_anomaly(True) option gives me no additional hints.
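For completeness, this is roughly how I enable anomaly detection (again just a sketch; training_step is a placeholder for the code that ends up calling the loss method above):

# Anomaly detection is switched on globally before the failing iteration;
# training_step is a placeholder name.
th.autograd.set_detect_anomaly(True)
loss_dict = training_step(batch)
loss_dict["loss"].backward()  # still raises only the generic RuntimeError shown above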