Efficient implementation of Shuffle BN in MoCo?

Shuffle BN is an important trick proposed by MoCo (Momentum Contrast for Unsupervised Visual Representation Learning):

We resolve this problem by shuffling BN. We train with multiple GPUs and perform BN on the samples independently for each GPU (as done in common practice). For the key encoder f_k, we shuffle the sample order in the current mini-batch before distributing it among GPUs (and shuffle back after encoding); the sample order of the mini-batch for the query encoder f_q is not altered. This ensures the batch statistics used to compute a query and its positive key come from two different subsets. This effectively tackles the cheating issue and allows training to benefit from BN.

Since the official code is not yet released, I tried to implement Shuffle BN as below (where the size of local tensor data is [32, 3, 224, 224]):

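def forward(self, data):
    N = data.size(0)  # local (per-GPU) batch size
    if self.training and self.shuffle_bn:
        # Gather the input tensors from all GPUs into one global batch.
        global_data = distributed_concat_no_grad(data, 0)
        shuffle_index = torch.randperm(global_data.size(0), device=data.device)
        # Broadcast so that every GPU uses the same permutation.
        broadcast(shuffle_index, 0)
        recover_index = shuffle_index.argsort()  # inverse permutation
        beg = N * self.rank
        end = beg + N
        # This GPU encodes its own slice of the shuffled global batch.
        data = global_data[shuffle_index[beg: end]]
    feature = self.some_feature_extracting_network(data)
    feature = feature.view(N, -1)
    if self.training and self.shuffle_bn:
        # Gather the features from all GPUs, then undo the shuffle.
        global_feature = distributed_concat_with_grad(feature)
        feature = global_feature[recover_index[beg: end]]
    return feature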
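The index bookkeeping in this forward pass can be sanity-checked on a single process. The sketch below is only a toy simulation, not MoCo's implementation: plain Python lists stand in for the per-GPU tensors, `simulate_shuffle_bn` and `encode` are hypothetical names, and `encode` is assumed to act on each sample independently (a real BN layer would not).

```python
import random


def argsort(seq):
    """Return the indices that sort seq; for a permutation this is its inverse."""
    return sorted(range(len(seq)), key=seq.__getitem__)


def simulate_shuffle_bn(per_rank_data, encode, seed=0):
    """Simulate shuffle -> encode -> shuffle back for all ranks on one process.

    per_rank_data: list of per-GPU sample lists, each of length N.
    encode: stand-in for the per-GPU encoder (list of samples -> list of features).
    """
    world_size, n = len(per_rank_data), len(per_rank_data[0])
    # Stand-in for the all-gather of the inputs.
    global_data = [x for chunk in per_rank_data for x in chunk]
    # One permutation shared by every rank (the broadcast in the real code).
    shuffle_index = list(range(world_size * n))
    random.Random(seed).shuffle(shuffle_index)
    recover_index = argsort(shuffle_index)

    # Each rank encodes its slice of the shuffled global batch;
    # extending one list stands in for the all-gather of the features.
    global_feature = []
    for rank in range(world_size):
        beg, end = rank * n, (rank + 1) * n
        shuffled = [global_data[i] for i in shuffle_index[beg:end]]
        global_feature.extend(encode(shuffled))

    # Each rank picks the features of its own original samples back out.
    return [
        [global_feature[i] for i in recover_index[rank * n:(rank + 1) * n]]
        for rank in range(world_size)
    ]
```

With an element-wise `encode`, every rank should get back the features of exactly its own samples in their original order, whatever the permutation is.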

However, the first allgather call makes training much slower (0.54 s/iter -> 0.84 s/iter).

Hey @WarBean

  1. Where is the allgather call? Do you mean the broadcast?
  2. Is this question about how to improve the efficiency?

Thanks for your reply.

1. distributed_concat_no_grad allgathers the data tensors from all GPUs.


It looks like, if you can know the value of global_data.size(0) without communication, you only need the actual contents of global_data at the end of the if statement. In that case, you can try launching an async allgather and only wait for it right before the shuffle, so that the comm can overlap with the other steps in between.
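A minimal sketch of that idea with torch.distributed follows. It assumes the helpers in the question wrap `torch.distributed.all_gather` and `torch.distributed.broadcast`, which the thread does not confirm; `start_all_gather`, `finish_all_gather`, and `shuffle_local_batch` are hypothetical names. Without an initialized process group it falls back to a single-process path so it can be tried on one machine.

```python
import torch
import torch.distributed as dist


def _dist_ready():
    """True only when a process group has been initialized."""
    return dist.is_available() and dist.is_initialized()


def start_all_gather(local):
    """Launch a non-blocking all-gather; return (handle, output buffers)."""
    if not _dist_ready():
        # Single-process fallback: nothing to communicate.
        return None, [local]
    buffers = [torch.empty_like(local) for _ in range(dist.get_world_size())]
    handle = dist.all_gather(buffers, local, async_op=True)
    return handle, buffers


def finish_all_gather(handle, buffers):
    """Wait for the all-gather (if any), then concatenate along dim 0."""
    if handle is not None:
        handle.wait()
    return torch.cat(buffers, dim=0)


def shuffle_local_batch(data, rank):
    """Return this rank's slice of the shuffled global batch and its recover indices."""
    n = data.size(0)
    world_size = dist.get_world_size() if _dist_ready() else 1

    handle, buffers = start_all_gather(data)  # communication starts here

    # Everything below needs only the global batch size, not the gathered
    # data, so it can overlap with the transfer.
    shuffle_index = torch.randperm(n * world_size, device=data.device)
    if _dist_ready():
        dist.broadcast(shuffle_index, src=0)  # same permutation on every rank
    recover_index = shuffle_index.argsort()
    beg, end = n * rank, n * (rank + 1)

    global_data = finish_all_gather(handle, buffers)  # wait only when needed
    return global_data[shuffle_index[beg:end]], recover_index[beg:end]
```

How much time this hides depends on how much work sits between the launch and the wait; in this sketch that is only the permutation and its broadcast, so the gain may be small unless more independent work is moved into that window.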

Another question is why you need to do the shuffle this way. Can you pre-shuffle the input data for multiple batches and then run multiple iterations without communication? If this is possible, you can both 1) consolidate smaller comms into larger ones and 2) launch multiple async comms and wait for all of them in one shot to saturate the bandwidth. Besides, it looks like the comm only applies to the input data; if so, you can even align one iteration with a previous comm, e.g., always let iteration i consume the comm result from iteration i - 2. In this way, comm i - 2 might have already finished before iteration i kicks off.
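The "iteration i consumes comm i - 2" pattern could be structured roughly as below. This is a generic sketch, not tied to any library: `pipelined` is a hypothetical helper, and `launch` is assumed to start a non-blocking operation and return an object whose `wait()` blocks and returns the result. The handle returned by `torch.distributed.all_gather(..., async_op=True)` only signals completion, so a thin wrapper holding the output buffers would be needed to fit this interface.

```python
from collections import deque


def pipelined(batches, launch, lag=2):
    """Yield finished results while keeping up to `lag` operations in flight.

    launch(batch) must start a non-blocking operation and return an object
    with a wait() method that blocks and returns the result (an assumed
    interface for this sketch).
    """
    in_flight = deque()
    for batch in batches:
        # Start the communication for a future iteration.
        in_flight.append(launch(batch))
        if len(in_flight) > lag:
            # Consume the operation launched `lag` steps earlier,
            # which has had time to finish in the background.
            yield in_flight.popleft().wait()
    # Drain the operations still in flight at the end.
    while in_flight:
        yield in_flight.popleft().wait()
```

A training loop would then iterate over `pipelined(loader, launch)` instead of the loader itself. The cost is that up to `lag` extra batches are held in memory at once.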