If I use the DataParrallel utility as such,
model = torch.nn.DataParallel(model)
my custom batchNorm module does not update it’s buffers, running_avg_mean and running_avg_std. It does update if I run the model without DataParrallel on a single GPU. How can I get the buffers to update in DataParrallel?
# Batch Renormalization for convolutional neural nets (2D) implementation based
# on https://arxiv.org/abs/1702.03275
from torch.nn import Module
import torch
class BatchRenormalization2D(Module):
'''Batch renorm from https://arxiv.org/pdf/1702.03275.pdf'''
def __init__(self, num_features, eps=1e-05, momentum=0.01, r_d_max_inc_step=0.0001):
super(BatchRenormalization2D, self).__init__()
self.eps = eps
self.momentum = momentum
self.gamma = torch.nn.Parameter(torch.ones((1, num_features, 1, 1)), requires_grad=True)
self.beta = torch.nn.Parameter(torch.zeros((1, num_features, 1, 1)), requires_grad=True)
self.register_buffer('running_avg_mean', torch.zeros((1, num_features, 1, 1)))
self.register_buffer('running_avg_std', torch.ones((1, num_features, 1, 1)))
self.max_r_max = 3.0
self.max_d_max = 5.0
self.r_max_inc_step = r_d_max_inc_step
self.d_max_inc_step = r_d_max_inc_step
self.r_max = 1.0
self.d_max = 0.0
def forward(self, x):
batch_ch_mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
batch_ch_std = torch.clamp(torch.std(x, dim=(0, 2, 3), keepdim=True), self.eps, 1e10)
if self.training:
r = torch.clamp(batch_ch_std / self.running_avg_std, 1.0 / self.r_max, self.r_max).data
d = torch.clamp((batch_ch_mean - self.running_avg_mean) / self.running_avg_std, -self.d_max, self.d_max).data
x = ((x - batch_ch_mean) * r )/ batch_ch_std + d
x = self.gamma * x + self.beta
if self.r_max < self.max_r_max:
self.r_max += self.r_max_inc_step * x.shape[0]
if self.d_max < self.max_d_max:
self.d_max += self.d_max_inc_step * x.shape[0]
self.running_avg_mean = self.running_avg_mean + self.momentum * (batch_ch_mean.detach() - self.running_avg_mean)
self.running_avg_std = self.running_avg_std + self.momentum * (batch_ch_std.detach() - self.running_avg_std)
else:
x = (x - self.running_avg_mean) / self.running_avg_std
x = self.gamma * x + self.beta
return x