RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. (BatchNorm2d track_running_stats)

If we do a second forward pass on a model containing `nn.BatchNorm2d` with `track_running_stats = True`, while in `torch.nn.parallel.DistributedDataParallel` mode, the code throws the error:
RuntimeError: one of the variables needed for gradient computation has been modified by an in-place operation

I’m attaching the code to recreate the error. (debug2.py - Google Drive)
Run it as:

CUDA_VISIBLE_DEVICES=0 python -W ignore -m torch.distributed.launch --master_port 12345 --nproc_per_node=1 --use_env debug2.py
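For reference, the failing pattern is roughly the following (a minimal sketch under my assumptions about debug2.py, not the attachment itself):

```python
# Hypothetical minimal repro: two forward passes through a DDP-wrapped
# model containing BatchNorm2d with track_running_stats=True, then one
# backward over both outputs.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by the launcher with --use_env
torch.cuda.set_device(local_rank)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

x = torch.randn(4, 3, 32, 32, device="cuda")
out1 = model(x)  # first forward: autograd saves tensors for backward
out2 = model(x)  # second forward: BN running stats are updated in place again
(out1.sum() + out2.sum()).backward()  # RuntimeError: ... inplace operation
```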

The only ways I have found to resolve the error are either:

  • setting `track_running_stats = False`, or
  • not running the model in `torch.nn.parallel.DistributedDataParallel`.

Is there a better way to solve this problem? It doesn't exist outside `torch.nn.parallel.DistributedDataParallel`; for example, `model = nn.DataParallel(model).cuda()` works fine.
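Sketched out, the two workarounds look like this (`model` is a placeholder):

```python
# Option 1: disable running statistics on every BatchNorm layer
bn = nn.BatchNorm2d(8, track_running_stats=False)

# Option 2: wrap in DataParallel instead of DistributedDataParallel
model = nn.DataParallel(model).cuda()
```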

In the `nn.BatchNorm2d` implementation,

```python
F.batch_norm(
    input,
    self.running_mean if not self.training or self.track_running_stats else None,
    self.running_var if not self.training or self.track_running_stats else None,
    self.weight,
    self.bias,
    bn_training,
    exponential_average_factor,
    self.eps,
)
```

the call updates `self.running_mean` and `self.running_var` in place, which creates the version mismatch.
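The underlying mechanism is autograd's saved-tensor version counter; here is a generic illustration of the same failure, independent of BatchNorm:

```python
import torch

w = torch.randn(3, requires_grad=True)
buf = torch.zeros(3)

y = w * buf    # autograd saves buf, since it is needed to compute dL/dw
buf.add_(1.0)  # in-place update bumps buf's version counter
y.sum().backward()  # RuntimeError: one of the variables needed for gradient
                    # computation has been modified by an inplace operation
```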

Doing:

```python
# keep copies so that F.batch_norm's in-place stat update never touches
# the tensors autograd saved (requires `import copy`)
self.running_mean_copy = copy.deepcopy(self.running_mean)
self.running_var_copy = copy.deepcopy(self.running_var)

F.batch_norm(
    input,
    self.running_mean_copy if not self.training or self.track_running_stats else None,
    self.running_var_copy if not self.training or self.track_running_stats else None,
    self.weight,
    self.bias,
    bn_training,
    exponential_average_factor,
    self.eps,
)

# write the updated stats back out of place, so the original buffers'
# version counters are never bumped
self.running_mean = self.running_mean * 0 + self.running_mean_copy
self.running_var = self.running_var * 0 + self.running_var_copy
```

solves the issue.
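Wrapped up as a self-contained module, the workaround could look like this (a sketch; `PatchedBatchNorm2d` is my name, and the bookkeeping lines mirror the stock `nn.BatchNorm2d.forward`):

```python
import copy

import torch.nn as nn
import torch.nn.functional as F

class PatchedBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that hands F.batch_norm cloned running stats, so the
    in-place stat update never bumps a saved tensor's version counter."""

    def forward(self, input):
        self._check_input_dim(input)

        # bookkeeping copied from the stock nn.BatchNorm2d.forward
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
                if self.momentum is None:  # cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # clones that F.batch_norm may update in place
        running_mean_copy = copy.deepcopy(self.running_mean)
        running_var_copy = copy.deepcopy(self.running_var)

        out = F.batch_norm(
            input,
            running_mean_copy if not self.training or self.track_running_stats else None,
            running_var_copy if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )

        # write the updated stats back out of place: this rebinds the buffers
        # to new tensors, leaving the old versions untouched
        if self.running_mean is not None:
            self.running_mean = self.running_mean * 0 + running_mean_copy
            self.running_var = self.running_var * 0 + running_var_copy
        return out
```

Swapping this class in for `nn.BatchNorm2d` in the repro above should be enough to test it.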

Oh, it is surprising that DDP is causing this. Could you open an issue on GitHub to track this, please?

Done:

Are `self.running_mean_copy` and `self.running_var_copy` buffers, or are they params?
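For reference: as assigned in the workaround snippet, with a plain `self.x = ...` under an unregistered name, they are neither. They become ordinary tensor attributes, so they do not appear in `state_dict` and DDP will not broadcast them. Making one a buffer would require explicit registration:

```python
# plain attribute, as in the workaround: neither buffer nor parameter
self.running_mean_copy = copy.deepcopy(self.running_mean)

# alternative: register as a buffer so it lands in state_dict, follows
# .to()/.cuda(), and is broadcast by DDP
# self.register_buffer("running_mean_copy", self.running_mean.detach().clone())
```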