Hi, I am training a resnet50 on 2 dataloaders and I want to combine the two losses. However, it always fails with: "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!"
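Stripped of AMP/DDP/gradient accumulation, the pattern is just two forward passes through the same network and one backward on the summed loss. A toy version of what I am doing (random tensors standing in for my real loaders):

import torch
from timm import create_model

model = create_model('resnet50', num_classes=10)
model.train()
loss_fn = torch.nn.CrossEntropyLoss()

input, target = torch.randn(4, 3, 224, 224), torch.randint(0, 10, (4,))
new_input, new_target = torch.randn(4, 3, 224, 224), torch.randint(0, 10, (4,))

loss = loss_fn(model(input), target)            # batch from the first loader
loss_2 = loss_fn(model(new_input), new_target)  # batch from the second loader
(loss + loss_2).backward()                      # single backward on the combined loss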
I added torch.autograd.set_detect_anomaly(True), and it shows that the problem happens in the batch norm:
"Traceback (most recent call last):
/ext3/miniconda3/lib/python3.11/site-packages/torch/autograd/init.py:200: UserWarning: Error detected in MiopenBatchNormBackward0. Traceback of forward call that caused the error:
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 1186, in
run_main()
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 400, in run_main
train_func(args)
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 853, in train_func
train_metrics = train_one_epoch(
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 1046, in train_one_epoch
loss = _forward()
File “/scratch/jx2209/pytorch-image-models/train_on_glc_laion.py”, line 1007, in _forward
output = model(input)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/distributed.py”, line 1156, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/parallel/distributed.py”, line 1110, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/scratch/jx2209/pytorch-image-models/timm/models/resnet.py”, line 541, in forward
x = self.forward_features(x)
File “/scratch/jx2209/pytorch-image-models/timm/models/resnet.py”, line 531, in forward_features
x = self.layer4(x)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py”, line 217, in forward
input = module(input)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/scratch/jx2209/pytorch-image-models/timm/models/resnet.py”, line 191, in forward
x = self.bn3(x)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py”, line 1501, in _call_impl
return forward_call(*args, **kwargs)
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py”, line 171, in forward
return F.batch_norm(
File “/ext3/miniconda3/lib/python3.11/site-packages/torch/nn/functional.py”, line 2450, in batch_norm
return torch.batch_norm(
(Triggered internally at …/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
"
I wonder how I can solve this issue, as I can't find an in-place operation here.
My train_one_epoch function is here:
from contextlib import suppress
from itertools import cycle
import time

import torch

from timm import utils
from timm.models import model_parameters


def train_one_epoch(
        epoch,
        model,
        loader,        # first dataset
        new_loader,    # second dataset
        optimizer,
        loss_fn,
        args,
        device=torch.device('cuda'),
        lr_scheduler=None,
        saver=None,
        output_dir=None,
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
):
    torch.autograd.set_detect_anomaly(True)
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    has_no_sync = hasattr(model, "no_sync")
    update_time_m = utils.AverageMeter()
    data_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()

    model.train()

    # iterate for as many batches as the larger loader has; cycle the second one
    max_loader = max(len(loader), len(new_loader))
    new_loader = cycle(new_loader)

    accum_steps = args.grad_accum_steps
    last_accum_steps = max_loader % accum_steps if max_loader % accum_steps != 0 else accum_steps
    updates_per_epoch = (max_loader + accum_steps - 1) // accum_steps
    num_updates = epoch * updates_per_epoch
    last_batch_idx = max_loader - 1
    last_batch_idx_to_accum = max_loader - last_accum_steps

    data_start_time = update_start_time = time.time()
    optimizer.zero_grad()
    update_sample_count = 0
    # for batch_idx, (input, target) in enumerate(loader):
    for batch_idx, ((input, target), (new_input, new_target)) in enumerate(zip(loader, new_loader)):
        # Adjust batch handling here as per your requirements
        last_batch = batch_idx == last_batch_idx
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
        update_idx = batch_idx // accum_steps
        if batch_idx >= last_batch_idx_to_accum:
            accum_steps = last_accum_steps

        if not args.prefetcher:
            input, target = input.to(device), target.to(device)
            new_input, new_target = new_input.to(device), new_target.to(device)
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
                new_input, new_target = mixup_fn(new_input, new_target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
            new_input = new_input.contiguous(memory_format=torch.channels_last)

        # multiply by accum steps to get equivalent for full update
        data_time_m.update(accum_steps * (time.time() - data_start_time))

        def _forward():
            with amp_autocast():
                output = model(input)
                loss = loss_fn(output, target)
                # train on trusted data
                new_output = model(new_input)
                loss_2 = loss_fn(new_output, new_target)
                # combine the two losses
                total_loss = loss + loss_2
            if accum_steps > 1:
                total_loss = total_loss / accum_steps
            return total_loss

        def _backward(_loss):
            if loss_scaler is not None:
                loss_scaler(
                    _loss,
                    optimizer,
                    clip_grad=args.clip_grad,
                    clip_mode=args.clip_mode,
                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                    create_graph=second_order,
                    need_update=need_update,
                )
            else:
                _loss.backward(create_graph=second_order)
                if need_update:
                    if args.clip_grad is not None:
                        utils.dispatch_clip_grad(
                            model_parameters(model, exclude_head='agc' in args.clip_mode),
                            value=args.clip_grad,
                            mode=args.clip_mode,
                        )
                    optimizer.step()

        if has_no_sync and not need_update:
            with model.no_sync():
                loss = _forward()
                _backward(loss)
        else:
            loss = _forward()
            _backward(loss)

        if not args.distributed:
            losses_m.update(loss.item() * accum_steps, input.size(0))
        update_sample_count += input.size(0)

        if not need_update:
            data_start_time = time.time()
            continue

        num_updates += 1
        optimizer.zero_grad()
        if model_ema is not None:
            model_ema.update(model)

        if args.synchronize_step and device.type == 'cuda':
            torch.cuda.synchronize()
        time_now = time.time()
        update_time_m.update(time.time() - update_start_time)
        update_start_time = time_now
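For completeness, this is roughly how I call it (simplified; first_dataset, second_dataset, and num_epochs are placeholders, my real script builds everything with timm's helpers, and the model is wrapped in DDP as the traceback shows):

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
loader = torch.utils.data.DataLoader(first_dataset, batch_size=args.batch_size, shuffle=True)
new_loader = torch.utils.data.DataLoader(second_dataset, batch_size=args.batch_size, shuffle=True)

for epoch in range(num_epochs):
    train_metrics = train_one_epoch(
        epoch, model, loader, new_loader, optimizer, loss_fn, args,
        lr_scheduler=lr_scheduler,
        amp_autocast=amp_autocast,
        loss_scaler=loss_scaler,
        model_ema=model_ema,
        mixup_fn=mixup_fn,
    )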