BatchNorm layers cause a training error when track_running_stats=True with DistributedDataParallel

When using DDP (PyTorch 1.12.1), some of my batch norm layers cause the training to fail due to an in-place operation, with the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [65]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
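The "is at version 3; expected version 2" part of this message refers to autograd's version counter: every in-place operation increments a tensor's counter, and a tensor saved for the backward pass remembers the counter value at the time it was saved. If the two values differ when backward runs, autograd raises this error. A toy sketch unrelated to BatchNorm (it reads the internal _version attribute purely for illustration):

```python
import torch


def trigger_version_mismatch():
    """Toy reproduction of the 'modified by an inplace operation' error."""
    a = torch.randn(3, requires_grad=True)
    b = torch.sigmoid(a)  # sigmoid's backward reuses its output, so autograd saves b
    before = b._version   # version recorded when b was saved (0 for a fresh tensor)
    b.mul_(2)             # any in-place op bumps b's version counter
    after = b._version
    try:
        b.sum().backward()  # the saved b no longer matches its recorded version
    except RuntimeError as err:
        return before, after, str(err)
    return before, after, None


print(trigger_version_mismatch()[:2])  # (0, 1)
```

In the DDP case the [65]-element tensor (one value per channel) is some per-channel tensor that the BatchNorm backward pass expected to be unchanged.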

The operation that failed was simply:

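import torch.nn as nn

# Assumption: autopad is the YOLOv5-style "same padding" helper and bn_params is the
# poster's own BatchNorm kwargs dict; neither is shown in the post.
bn_params = {}  # placeholder for the real BatchNorm kwargs

def autopad(k, p=None):
    # 'Same'-style padding for odd kernel sizes when no explicit padding is given
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p

class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, **bn_params)
        self.act = nn.SiLU(inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)    # <--------------------------------------- this fails
        x = self.act(x)
        return x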

On a whim, I tried passing:
self.bn = nn.BatchNorm2d(c2, track_running_stats=False, **bn_params)
to all my batch norm layers and the training ran, but of course this is not a viable solution.
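For context on why track_running_stats=False is not a drop-in fix: with that flag the layer keeps no running buffers at all and normalizes with the current batch statistics even in eval mode, so inference results depend on batch composition. A minimal CPU sketch contrasting the two settings (layer and tensor sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn_tracked = nn.BatchNorm2d(4)                               # default: track_running_stats=True
bn_untracked = nn.BatchNorm2d(4, track_running_stats=False)  # the workaround from the post

print(bn_tracked.running_mean)    # tensor([0., 0., 0., 0.]): running buffers exist
print(bn_untracked.running_mean)  # None: nothing is tracked

x = torch.randn(8, 4, 5, 5)
bn_tracked.eval()
bn_untracked.eval()

# Eval mode, tracked: normalizes with the stored running stats (initially mean 0, var 1),
# so a freshly created layer returns (almost) its input unchanged.
y_tracked = bn_tracked(x)

# Eval mode, untracked: still normalizes with the statistics of the current batch.
y_untracked = bn_untracked(x)
print(y_untracked.mean(dim=(0, 2, 3)))  # per-channel means are ~0
```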

For the record, I also tried cloning x and setting nn.SiLU(inplace=False) but got the same error.

I cannot reproduce the issue using your module in this script:

import torch
import torch.nn as nn

class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, 1, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU(inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x
    
model = Conv(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for _ in range(10):
    optimizer.zero_grad()
    x = torch.randn(1, 1, 24, 24)
    out = model(x)
    out.mean().backward()
    optimizer.step()

Thanks for the response, @ptrblck.

While creating a minimal reproducible example, I was able to narrow down the source of the error: during training I run two forward passes and then a single backward pass on the combined losses. If I run only one forward pass, there is no error.

Am I correct in assuming that the batch norm layers update their running stats during each forward pass? That could explain the error.
If so, why is it only triggered with DDP?

One more thing: I use Huggingface Accelerate for convenience to wrap DDP, but I don’t think it’s relevant in this case.

import torch
import torch.nn as nn
from accelerate import Accelerator

class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, g=1):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, 1, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x

accelerate = Accelerator()
model = Conv(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = accelerate.prepare(model, optimizer)

with torch.autograd.set_detect_anomaly(True):
    for _ in range(10):
        optimizer.zero_grad()
        x1 = torch.randn(1, 1, 24, 24).to(accelerate.device)  # match the model's device
        x2 = torch.randn(1, 1, 24, 24).to(accelerate.device)
        out1 = model(x1)
        out2 = model(x2)
        out = out1 + out2
        out.mean().backward()
        optimizer.step()

Your code snippet still works for me and I don’t get the error after removing Accelerator (I haven’t installed the needed dependencies for it). Do you see the error using your code snippet? If so, does dropping Accelerator help?

You are correct that the running stats get updated in each forward pass, but this won’t explain the error.
The running stats are buffers (i.e. they won’t be updated by Autograd) and are not even used during training.
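A quick CPU sketch of the points above: the running stats are buffers rather than parameters, they are updated on every training-mode forward pass, and the training-mode output does not use them (sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
bn.train()

# weight/bias are trainable parameters; the running stats are plain buffers
param_names = {name for name, _ in bn.named_parameters()}
buffer_names = {name for name, _ in bn.named_buffers()}
print(sorted(param_names))   # ['bias', 'weight']
print(sorted(buffer_names))  # ['num_batches_tracked', 'running_mean', 'running_var']

x = torch.randn(4, 3, 6, 6)
out_a = bn(x)
out_b = bn(x)                  # running stats are updated again on this call...
print(bn.num_batches_tracked)  # tensor(2)

# ...but training-mode output only uses the current batch statistics, so even
# garbage running stats leave it unchanged.
with torch.no_grad():
    bn.running_mean.fill_(100.0)
    bn.running_var.fill_(1e-3)
out_c = bn(x)
print(torch.allclose(out_a, out_b), torch.allclose(out_a, out_c))  # True True
```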


If I omit Accelerator, the error disappears, but all the package does is wrap the training in DDP.
The following code, which uses only PyTorch modules, also produces the error. I tested it on PyTorch 1.11.0 and 1.12.1, and I still get the error even with CUDA_VISIBLE_DEVICES=0.

import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, g=1):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, 1, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x

def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

def main(rank: int, world_size: int):
    ddp_setup(rank, world_size)
    model = Conv(1, 1).to(rank)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model = DDP(model, device_ids=[rank])

    with torch.autograd.set_detect_anomaly(True):
        for _ in range(10):
            optimizer.zero_grad()
            x1 = torch.randn(1, 1, 24, 24).to(rank)
            x2 = torch.randn(1, 1, 24, 24).to(rank)
            out1 = model(x1)
            out2 = model(x2)
            out = out1 + out2
            out.mean().backward()
            optimizer.step()

    destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size,), nprocs=world_size)

This yields the following error message:

/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in CudnnBatchNormBackward0. Traceback of forward call that caused the error:
  File "<string>", line 1, in <module>
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/multiprocessing/spawn.py", line 129, in _main
    return self._bootstrap(parent_sentinel)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/anba/catkin_ws/src/tas_dev/dev/anba/superpoint/DDP_bug_2.py", line 37, in main
    out1 = model(x1)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/anba/catkin_ws/src/tas_dev/dev/anba/superpoint/DDP_bug_2.py", line 17, in forward
    x = self.bn(x)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/nn/functional.py", line 2438, in batch_norm
    return torch.batch_norm(
 (Triggered internally at  /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1660083882787/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "DDP_bug_2.py", line 48, in <module>
    mp.spawn(main, args=(world_size,), nprocs=world_size)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/anba/catkin_ws/src/tas_dev/dev/anba/superpoint/DDP_bug_2.py", line 40, in main
    out.mean().backward()
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/anba/anaconda3/envs/superpoint_anba/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Were you able to reproduce the error, @ptrblck?
If not, then I give up.

Yes, I’m able to reproduce the issue, and it seems it might be caused by running multiple forward passes before a single backward pass.
I.e., if I run a backward pass after each forward pass, it works for me:

            out1 = model(x1)
            out1.mean().backward()
            out2 = model(x2)
            out2.mean().backward()
            optimizer.step()
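Why the split version is a safe rewrite: mean() is linear, so the gradient of (out1 + out2).mean() equals the gradient of out1.mean() plus that of out2.mean(), and .grad accumulates across backward() calls until zero_grad(). A minimal single-process CPU check without DDP, reusing the Conv module from the thread (the tolerance is an arbitrary choice):

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)


class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, g=1):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, 1, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


model_joint = Conv(1, 1)
model_split = copy.deepcopy(model_joint)  # identical starting weights
x1 = torch.randn(1, 1, 24, 24)
x2 = torch.randn(1, 1, 24, 24)

# (a) two forwards, one backward on the summed output (the pattern that fails under DDP)
(model_joint(x1) + model_joint(x2)).mean().backward()

# (b) backward after each forward; .grad accumulates across the two calls
model_split(x1).mean().backward()
model_split(x2).mean().backward()

matches = [
    torch.allclose(p_joint.grad, p_split.grad, atol=1e-5)
    for p_joint, p_split in zip(model_joint.parameters(), model_split.parameters())
]
print(matches)  # one entry each for conv.weight, bn.weight, bn.bias
```

Under DDP, each backward() call runs its own gradient all-reduce, so the split version communicates twice per optimizer step.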

I also don’t know if this is a known limitation of DDP, but I think @kwen2501 might know.

Thanks for asking, this is indeed an interesting case. Not sure if gradient bucketing can affect this. cc: @rvarm1

Is this issue related to "BatchNorm runtimeError: one of the variables needed for gradient computation has been modified by an inplace operation" (Issue #66504 in pytorch/pytorch on GitHub)? A few workarounds are mentioned in that issue.

Thanks for the link. The issues seem to be related, but neither setting broadcast_buffers=False nor using SyncBatchNorm solves my problem.
SyncBatchNorm gives me the same error.
Not broadcasting the buffers lets the training run for a few iterations (the number varies between runs), but then I get:
RuntimeError: Function 'CudnnBatchNormBackward0' returned nan values in its 0th output.

The only workaround that partially worked was wrapping the model in Fully Sharded Data Parallel (FSDP).

Edit: I tried training on my other machine using PyTorch 1.12 (I was using 1.11 before), and it worked fine with broadcast_buffers=False.
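The broadcast_buffers=False setup reported to work on 1.12 can be sketched as below. Assumptions: a single-process "gloo" group on CPU, backed by an in-memory HashStore, stands in for the multi-GPU NCCL setup, so this only illustrates the wiring; the helper name train_two_forward_passes is hypothetical. With this flag DDP no longer broadcasts rank 0's buffers at the start of each forward, so each rank keeps its own BatchNorm running statistics computed from its local batches.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, g=1):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, 1, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


def train_two_forward_passes(steps: int = 3) -> list:
    """Hypothetical helper: two forwards + one backward per step under DDP."""
    # Single-process CPU group for illustration only; a real run would use one
    # process per GPU with the "nccl" backend, as in the mp.spawn script above.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    try:
        # broadcast_buffers=False: skip the per-forward broadcast of BN running stats
        model = DDP(Conv(1, 1), broadcast_buffers=False)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        losses = []
        for _ in range(steps):
            optimizer.zero_grad()
            out1 = model(torch.randn(1, 1, 24, 24))
            out2 = model(torch.randn(1, 1, 24, 24))
            loss = (out1 + out2).mean()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        return losses
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(train_two_forward_passes())
```

Because the running stats are no longer synchronized, a checkpoint saved from rank 0 carries only rank 0's statistics.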