Shuffle BN is an important trick proposed by MoCo (Momentum Contrast for Unsupervised Visual Representation Learning):
We resolve this problem by shuffling BN. We train with multiple GPUs and perform BN on the samples independently for each GPU (as done in common practice). For the key encoder f_k, we shuffle the sample order in the current mini-batch before distributing it among GPUs (and shuffle back after encoding); the sample order of the mini-batch for the query encoder f_q is not altered. This ensures the batch statistics used to compute a query and its positive key come from two different subsets. This effectively tackles the cheating issue and allows training to benefit from BN.
Since the official code is not yet released, I tried to implement Shuffle BN as below (the local tensor `data` has size [32, 3, 224, 224]):
```python
def forward(self, data):
    N = data.size(0)
    if self.training and self.shuffle_bn:
        # Gather the mini-batches from all GPUs into one global batch.
        global_data = distributed_concat_no_grad(data, 0)
        # Shuffle the global sample order; broadcast so every rank uses the same permutation.
        shuffle_index = torch.randperm(global_data.size(0), device=data.device)
        broadcast(shuffle_index, 0)
        recover_index = shuffle_index.argsort()
        # Each rank keeps its own slice of the shuffled global batch.
        beg = N * self.rank
        end = beg + N
        data = global_data[shuffle_index[beg: end]]
    feature = self.some_feature_extracting_network(data)
    feature = feature.view(N, -1)
    if self.training and self.shuffle_bn:
        # Gather features from all GPUs and restore the original (unshuffled) order.
        global_feature = distributed_concat_with_grad(feature)
        feature = global_feature[recover_index[beg: end]]
    return feature
```
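The snippet relies on a few helpers that are not shown. Below is a minimal sketch of how they could be implemented with `torch.distributed`; the names follow the snippet above, and the gradient-aware gather uses a custom autograd `Function`, which is one common way to do it (these are my assumptions, not necessarily how the actual utilities are written):

```python
import torch
import torch.distributed as dist


def broadcast(tensor, src=0):
    # Make the shuffle permutation identical on every rank.
    dist.broadcast(tensor, src)


def distributed_concat_no_grad(tensor, dim=0):
    # Gather the tensor from all ranks and concatenate; no gradient flows back.
    with torch.no_grad():
        gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor.contiguous())
    return torch.cat(gathered, dim=dim)


class _GatherWithGrad(torch.autograd.Function):
    """all_gather whose backward routes gradients back to the rank that produced each row."""

    @staticmethod
    def forward(ctx, tensor):
        ctx.rank = dist.get_rank()
        ctx.local_size = tensor.size(0)
        gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor.contiguous())
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the gradient of the gathered tensor over all ranks, then keep the
        # rows that correspond to this rank's local input.
        grad = grad_output.contiguous().clone()
        dist.all_reduce(grad)
        beg = ctx.rank * ctx.local_size
        return grad[beg: beg + ctx.local_size]


def distributed_concat_with_grad(tensor):
    return _GatherWithGrad.apply(tensor)
```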
However, the first all_gather call (the one that gathers the raw input images) makes training much slower (0.54 s/iter -> 0.84 s/iter).
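One idea I am considering to cut the communication volume (an untested sketch, not something from the MoCo paper): each GPU only needs 1/world_size of the shuffled global batch, so instead of all_gathering the full image tensor to every rank, `torch.distributed.all_to_all_single` could exchange exactly the samples each rank needs. The feature unshuffle could keep using the gathered-feature path, since the features are much smaller than the images. The helper name `shuffle_exchange` is hypothetical:

```python
import torch
import torch.distributed as dist


def shuffle_exchange(data, shuffle_index, rank, world_size):
    """Send each local sample directly to the GPU that needs it after shuffling.

    `shuffle_index` is the same global permutation on every rank (e.g. broadcast
    from rank 0). Returns this rank's slice of the shuffled global batch, i.e.
    the samples with global indices shuffle_index[rank*N : (rank+1)*N], in order.
    """
    N = data.size(0)
    recover_index = shuffle_index.argsort()

    # Destination rank of local sample p (global index rank*N + p).
    local_pos = torch.arange(N, device=data.device)
    dest = recover_index[rank * N + local_pos] // N

    # Order local samples by (destination rank, local index) so that each receiver
    # gets its samples sorted by global index.
    send_order = torch.argsort(dest * N + local_pos)
    send_buf = data[send_order].contiguous()
    input_splits = torch.bincount(dest, minlength=world_size).tolist()

    # Samples this rank will receive: its slice of the shuffled order.
    wanted = shuffle_index[rank * N:(rank + 1) * N]
    output_splits = torch.bincount(wanted // N, minlength=world_size).tolist()

    recv_buf = torch.empty_like(send_buf)
    dist.all_to_all_single(recv_buf, send_buf,
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits)

    # recv_buf holds the wanted samples sorted by global index; reorder them
    # to follow the shuffled order exactly.
    unsort = torch.argsort(wanted)
    out = torch.empty_like(recv_buf)
    out[unsort] = recv_buf
    return out
```

In principle this moves roughly 1/world_size of the data that the full all_gather transfers per GPU, but I have not benchmarked it, so I cannot say how much of the 0.3 s/iter overhead it would actually recover.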