Shuffle BN is an important trick proposed by MoCo (Momentum Contrast for Unsupervised Visual Representation Learning):
We resolve this problem by shuffling BN. We train with multiple GPUs and perform BN on the samples independently for each GPU (as done in common practice). For the key encoder f_k, we shuffle the sample order in the current mini-batch before distributing it among GPUs (and shuffle back after encoding); the sample order of the mini-batch for the query encoder f_q is not altered. This ensures the batch statistics used to compute a query and its positive key come from two different subsets. This effectively tackles the cheating issue and allows training to benefit from BN.
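The core of the trick is that shuffling with a shared random permutation before encoding and restoring the original order afterwards is lossless: the `argsort` of a permutation is its inverse. A minimal single-process sketch (plain tensors, hypothetical sizes, no actual distributed communication) illustrates the index arithmetic:

```python
import torch

# Pretend global batch of 8 samples (e.g. 2 GPUs x 4 samples each).
global_data = torch.arange(8)

# Shared permutation; in the distributed setting this would be
# broadcast from rank 0 so every process uses the same order.
shuffle_index = torch.randperm(8)

# argsort of a permutation is its inverse permutation: it maps
# shuffled positions back to the original positions.
recover_index = shuffle_index.argsort()

shuffled = global_data[shuffle_index]   # order seen by the key encoder
restored = shuffled[recover_index]      # original order recovered

assert torch.equal(restored, global_data)
```

The same two index tensors drive both halves of the `forward` below: `shuffle_index` picks each rank's slice of the shuffled global batch, and `recover_index` puts the gathered features back in query order.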
Since the official code has not been released yet, I tried to implement Shuffle BN as below (the local tensor `data` has shape [32, 3, 224, 224]):
```python
def forward(self, data):
    N = data.size(0)
    if self.training and self.shuffle_bn:
        # Gather the mini-batches from all GPUs into one global batch.
        global_data = distributed_concat_no_grad(data, 0)
        # Use the same permutation on every rank: generate locally,
        # then overwrite with rank 0's copy via broadcast.
        shuffle_index = torch.randperm(global_data.size(0), device=data.device)
        broadcast(shuffle_index, 0)
        # argsort of a permutation is its inverse.
        recover_index = shuffle_index.argsort()
        # Each rank processes its own slice of the shuffled global batch.
        beg = N * self.rank
        end = beg + N
        data = global_data[shuffle_index[beg:end]]
    feature = self.some_feature_extracting_network(data)
    feature = feature.view(N, -1)
    if self.training and self.shuffle_bn:
        # Gather all features and restore the original sample order.
        global_feature = distributed_concat_with_grad(feature)
        feature = global_feature[recover_index[beg:end]]
    return feature
```
However, the first all-gather communication (on the raw input data) makes training much slower (0.54 s/iter → 0.84 s/iter).