If we do a second forward pass (before calling backward) on a model that contains nn.BatchNorm2d with track_running_stats=True while it is wrapped in `torch.nn.parallel.DistributedDataParallel`, the subsequent backward throws:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

The problem doesn't exist outside `torch.nn.parallel.DistributedDataParallel`; for example, doing model = nn.DataParallel(model).cuda() works fine. Is there a better way to solve this than the workaround below?
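For concreteness, here is a minimal sketch of the pattern I mean; the single-process gloo group, the toy model, the sizes, and the loss are placeholders rather than my actual setup, and the exact behaviour may depend on the PyTorch version and DDP settings:

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist

# Illustrative single-process group (gloo, CPU); the real setup uses multiple
# processes/GPUs, but the forward-forward-backward pattern is the same.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),  # track_running_stats=True by default
)
model = nn.parallel.DistributedDataParallel(model)

x = torch.randn(2, 3, 16, 16)
out1 = model(x)                 # first forward pass
out2 = model(x)                 # second forward pass, before any backward
loss = out1.sum() + out2.sum()
loss.backward()                 # RuntimeError: ... modified by an inplace operation

dist.destroy_process_group()
```

As far as I can tell, the root cause is that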
F.batch_norm(
input,
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)
does an in-place update of self.running_mean and self.running_var, which creates the version mismatch.
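You can watch the buffer's version counter tick even without DDP (minimal standalone sketch; `_version` is an internal tensor attribute, used here purely for illustration):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)           # track_running_stats=True by default
x = torch.randn(2, 4, 8, 8)

print(bn.running_mean._version)  # 0 before any forward pass
bn(x)                            # training mode: running stats are updated in place
print(bn.running_mean._version)  # bumped by the in-place update
```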
My current workaround is to patch BatchNorm's forward so that F.batch_norm operates on copies of the running statistics (copy here is the standard-library module):

# make copies so F.batch_norm's in-place update doesn't touch the original buffers
self.running_mean_copy = copy.deepcopy(self.running_mean)
self.running_var_copy = copy.deepcopy(self.running_var)
output = F.batch_norm(
    input,
    self.running_mean_copy if not self.training or self.track_running_stats else None,
    self.running_var_copy if not self.training or self.track_running_stats else None,
    self.weight,
    self.bias,
    bn_training,
    exponential_average_factor,
    self.eps,
)
# write the updated statistics back out of place (no in-place modification of the original buffers)
self.running_mean = self.running_mean * 0 + self.running_mean_copy
self.running_var = self.running_var * 0 + self.running_var_copy
return output