Inplace operation error for gradient computation when trained on two dataloaders

Hi, I am training a resnet50 on 2 dataloaders, and I want to combine the loss. However, it always shows “RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!”

I add torch.autograd.set_detect_anomaly/True), and it shows the problem happens in the batch norm:
"Traceback (most recent call last):
/ext3/miniconda3/lib/python3.11/site-packages/torch/autograd/init.py:200: UserWarning: Error detected in MiopenBatchNormBackward0. Traceback of forward call that caused the error:
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 1186, in
run_main()
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 400, in run_main
train_func(args)
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 853, in train_func
train_metrics = train_one_epoch(
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 1046, in train_one_epoch
loss = _forward()
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 1007, in _forward
output = model(input)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/distributed.py”, line 1156, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/distributed.py”, line 1110, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/scratch/jx2209/pytorch-image-models/timm/models/resnet.py”, line 541, in forward
x = self.forward_features(x)
File “/scratch/jx2209/pytorch-image-models/timm/models/resnet.py”, line 531, in forward_features
x = self.layer4(x)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py”, line 217, in forward
input = module(input)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/scratch/jx2209/pytorch-image-models/timm/models/resnet.py”, line 191, in forward
x = self.bn3(x)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py”, line 171, in forward
return F.batch_norm(
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py”, line 2450, in batch_norm
return torch.batch_norm(
(Triggered internally at …/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
"

I wonder how can I solve this issue as I don’t find a in-place operation here.

My train_one_epoch function is here:

def train_one_epoch(
epoch,
model,
loader, #first dataset
new_loader, #second dataset
optimizer,
loss_fn,
args,
device=torch.device(‘cuda’),
lr_scheduler=None,
saver=None,
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None,
mixup_fn=None,
):
torch.autograd.set_detect_anomaly(True)
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False

second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
has_no_sync = hasattr(model, "no_sync")
update_time_m = utils.AverageMeter()
data_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()

model.train()

#decide larger cycle
max_loader = max(len(loader), len(new_loader))
new_loader = cycle(new_loader)

accum_steps = args.grad_accum_steps
last_accum_steps = max_loader % accum_steps if max_loader % accum_steps != 0 else accum_steps
updates_per_epoch = (max_loader + accum_steps - 1) // accum_steps
num_updates = epoch * updates_per_epoch
last_batch_idx = max_loader - 1
last_batch_idx_to_accum = max_loader - last_accum_steps

data_start_time = update_start_time = time.time()
optimizer.zero_grad()
update_sample_count = 0
#for batch_idx, (input, target) in enumerate(loader):
for batch_idx, ((input, target), (new_input, new_target)) in enumerate(zip(loader, new_loader)):
# Adjust batch handling here as per your requirements

    last_batch = batch_idx == last_batch_idx
    need_update = last_batch or (batch_idx + 1) % accum_steps == 0
    update_idx = batch_idx // accum_steps
    if batch_idx >= last_batch_idx_to_accum:
        accum_steps = last_accum_steps

    if not args.prefetcher:
        input, target = input.to(device), target.to(device)
        new_input, new_target = new_input.to(device), new_target.to(device)
        if mixup_fn is not None:
            input, target = mixup_fn(input, target)
            new_input, new_target = mixup_fn(new_input, new_target)
    if args.channels_last:
        input = input.contiguous(memory_format=torch.channels_last)
        new_input = new_input.contiguous(memory_format=torch.channels_last)
    # multiply by accum steps to get equivalent for full update
    data_time_m.update(accum_steps * (time.time() - data_start_time))

    def _forward():
       
        with amp_autocast():
            output = model(input)
            
            loss = loss_fn(output, target)
            #train on trusted data
            new_output = model(new_input)
            loss_2 = loss_fn(new_output, new_target)
            #combining the loss together
            total_loss = loss + loss_2
        if accum_steps > 1:
            total_loss = total_loss/ accum_steps
        return total_loss

    def _backward(_loss):
        if loss_scaler is not None:
            loss_scaler(
                _loss,
                optimizer,
                clip_grad=args.clip_grad,
                clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order,
                need_update=need_update,
            )
        else:
            _loss.backward(create_graph=second_order)
            if need_update:
                if args.clip_grad is not None:
                    utils.dispatch_clip_grad(
                        model_parameters(model, exclude_head='agc' in args.clip_mode),
                        value=args.clip_grad,
                        mode=args.clip_mode,
                    )
                optimizer.step()

    if has_no_sync and not need_update:
        with model.no_sync():
            loss = _forward()
            _backward(loss)
    else:
        loss = _forward()
        _backward(loss)

    if not args.distributed:
        losses_m.update(loss.item() * accum_steps, input.size(0))
    update_sample_count += input.size(0)

    if not need_update:
        data_start_time = time.time()
        continue

    num_updates += 1
    optimizer.zero_grad()
    if model_ema is not None:
        model_ema.update(model)

    if args.synchronize_step and device.type == 'cuda':
        torch.cuda.synchronize()
    time_now = time.time()
    update_time_m.update(time.time() - update_start_time)
    update_start_time = time_now

Hi Jiawei!

First, look in your code for tensors of shape [2048] that may be the tensor
that is being modified. Then for any plausible candidate tensor, t, print out
t._version at various place in your code to see if that tensor’s ._version
changes from 2 to 3. You can then use a binary-search strategy by printing
out t._version at additional places in your code to determine the specific
operation that is modifying t inplace.

If it turns out that t._version increases from 2 to 3 when you call into some
third-party code, continue your binary search adding print statements to the
third-party code you are calling into.

Once you’ve found the source of the inplace modification, the fix for the error
is sometimes relatively clear.

Further discussion of how to debug inplace-modification errors can be found
in the following post:

Good luck!

K. Frank

Hi Frank,
Thanks so much for your reply! I am using a standard renset50 and I am sure the error occurs at the batch norm after a short connection. However, I just wonder why it has an error here just by combining two losses on two dataloaders.