Module Buffers not updating in DataParrallel

If I use the DataParrallel utility as such,

 model = torch.nn.DataParallel(model)

my custom batchNorm module does not update it’s buffers, running_avg_mean and running_avg_std. It does update if I run the model without DataParrallel on a single GPU. How can I get the buffers to update in DataParrallel?

# Batch Renormalization for convolutional neural nets (2D) implementation based
# on https://arxiv.org/abs/1702.03275

from torch.nn import Module
import torch

class BatchRenormalization2D(Module):
    '''Batch renorm from https://arxiv.org/pdf/1702.03275.pdf'''

    def __init__(self, num_features, eps=1e-05, momentum=0.01, r_d_max_inc_step=0.0001):
        super(BatchRenormalization2D, self).__init__()

        self.eps = eps
        self.momentum = momentum

        self.gamma = torch.nn.Parameter(torch.ones((1, num_features, 1, 1)), requires_grad=True)
        self.beta = torch.nn.Parameter(torch.zeros((1, num_features, 1, 1)), requires_grad=True)

        self.register_buffer('running_avg_mean', torch.zeros((1, num_features, 1, 1)))
        self.register_buffer('running_avg_std', torch.ones((1, num_features, 1, 1)))

        self.max_r_max = 3.0
        self.max_d_max = 5.0

        self.r_max_inc_step = r_d_max_inc_step
        self.d_max_inc_step = r_d_max_inc_step

        self.r_max = 1.0
        self.d_max = 0.0

    def forward(self, x):

        batch_ch_mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
        batch_ch_std = torch.clamp(torch.std(x, dim=(0, 2, 3), keepdim=True), self.eps, 1e10)


        if self.training:

            r = torch.clamp(batch_ch_std / self.running_avg_std, 1.0 / self.r_max, self.r_max).data
            d = torch.clamp((batch_ch_mean - self.running_avg_mean) / self.running_avg_std, -self.d_max, self.d_max).data

            x = ((x - batch_ch_mean) * r )/ batch_ch_std + d
            x = self.gamma * x + self.beta

            if self.r_max < self.max_r_max:
                self.r_max += self.r_max_inc_step * x.shape[0]

            if self.d_max < self.max_d_max:
                self.d_max += self.d_max_inc_step * x.shape[0]

            self.running_avg_mean = self.running_avg_mean + self.momentum * (batch_ch_mean.detach() - self.running_avg_mean)
            self.running_avg_std = self.running_avg_std + self.momentum * (batch_ch_std.detach() - self.running_avg_std)

        else:

            x = (x - self.running_avg_mean) / self.running_avg_std
            x = self.gamma * x + self.beta

        return x

update, was able to make it work using the .lerp() function. Why does what I did not work?

Because in DP, the python module object is replicated to run on each GPU in a different thread. However, this setattr assigns the updated the buffer to the replica, which is lost right afterwards. Instead, inplace updates to the buffer works because buffers in the replica on the first GPU share memory with the original one.

can you please share how you make it work with .lerp()