Synchronization mechanism with different for-loop iteration counts

Hi,
I’m working on modifying my model (including my custom data loader) to fit the structure of DDP. I haven’t tried running my code yet, but I’d like to understand the synchronization process better.

According to the many great threads on this forum, DDP takes care of synchronization during loss.backward(). But what if the number of samples in each data loader leads to different for-loop counts across processes? Would the processes with n+1 loops be blocked because the processes with only n loops never reach that point?

Say I have 401 images distributed across 4 data loaders with 101, 100, 100, and 100 images respectively. The batch size is 4, so process 0 gets 26 iterations while the others get 25. Would my process group get stuck at the 26th iteration?
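(Sanity-checking that arithmetic, assuming each loader keeps its last partial batch rather than dropping it:)

import math

# 401 images split across 4 processes; batch size 4, partial batch kept
samples_per_rank = [101, 100, 100, 100]
iters_per_rank = [math.ceil(n / 4) for n in samples_per_rank]
print(iters_per_rank)  # [26, 25, 25, 25]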

Here is a simplified version of part of my code:

#......(some init, including wrapping self.model with DDP)......
for phase in ['train', 'eval']:
    dist.barrier()
    if phase=='train':
        self.model.train()
        self.data_loader.train()
    else:
        self.model.eval()
        self.data_loader.eval()
    running_loss = 0
    for inputs, labels in self.data_loader:
        self.optimizer.zero_grad()
        with torch.set_grad_enabled(phase=='train'):
            outputs = self.model(inputs)
            loss = self.loss(outputs, labels)
            if phase == 'train':
                loss.backward()   ### Could this or the following line get stuck during the extra loop by process 0?
                self.optimizer.step()
                running_loss += loss.item()*inputs.shape[0]
        torch.cuda.empty_cache()
    epoch_loss = running_loss/len(self.data_loader)

Thanks for any helpful hint!

Yep, the process with n+1 loops will block on PyTorch <= v1.6. There are ways to work around this in user code, e.g. by collecting a signal in each iteration to see if any process has already exited, and breaking if so.
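Roughly along these lines (just an untested sketch; has_next() is a placeholder for whatever end-of-iterator check your loader supports):

import torch
import torch.distributed as dist

# Every rank contributes a 0/1 flag each iteration; if the SUM is smaller than
# world_size, at least one rank has run out of data and everyone breaks together.
def train_one_epoch(model, loader, loss_fn, optimizer, device):
    for inputs, labels in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

        flag = torch.tensor([1.0 if has_next(loader) else 0.0], device=device)
        dist.all_reduce(flag)  # default op is SUM
        if flag.item() < dist.get_world_size():
            break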

@rvarm1 is working on a much better solution, which will be included in v1.7. With that solution, a process that exits early will issue dummy communication ops to unblock the remaining active ones. Please see the following issue and PR.
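If it lands the way it is currently proposed, usage should look roughly like this (sketch only; the name and signature may change before v1.7 ships):

from torch.nn.parallel import DistributedDataParallel as DDP

# Proposed uneven-inputs support: a join() context manager on the DDP wrapper.
# Ranks that exhaust their data early keep issuing shadow collectives so the
# still-active ranks are not blocked.
def train_one_epoch(ddp_model: DDP, loader, loss_fn, optimizer):
    with ddp_model.join():
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(inputs), labels)
            loss.backward()  # DDP gradient all-reduce happens here
            optimizer.step()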

A thousand thanks for the explanation! I modified my code following your suggestion, and I’m posting my provisional solution here for comments.


running_loss = 0
running_len = 0
iteration_count = 0
for inputs, labels in self.data_loader:
    self.optimizer.zero_grad()
    with torch.set_grad_enabled(phase=='train'):
        outputs = self.model(inputs)
        loss = self.loss(outputs, labels)
        if phase == 'train':
            loss.backward()
            self.optimizer.step()
            iteration_count+=1
    running_loss += loss.item()
    running_len += inputs.shape[0]
    torch.cuda.empty_cache()
    ##########
    is_next = torch.Tensor([self.data_loader.peek()])
    # is_next==True if the iterator has not reached the end, i.e., next loop is expected
    dist.all_reduce(is_next, op=dist.ReduceOp.BAND)
    if not is_next: break
    ##########

Hey @annisat, that looks good to me. One way to speed it up a bit is to launch the dist.all_reduce at the beginning of the loop with async_op=True, and only wait on it when you need the result. That way the communication can overlap with the forward/backward/optimizer.step() computation. Please see the code in the following thread:
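In outline, it would look something like this (again just a sketch, not the code from that thread; loader.peek() is your own end-of-iterator check):

import torch
import torch.distributed as dist

# Launch the "does every rank still have data?" collective first, overlap it with
# the forward/backward/step, and only wait on it when the result is needed.
def train_one_epoch(model, loader, loss_fn, optimizer, device):
    for inputs, labels in loader:
        flag = torch.tensor([1.0 if loader.peek() else 0.0], device=device)
        handle = dist.all_reduce(flag, async_op=True)  # returns a work handle

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

        handle.wait()  # block only now, when the flag is actually needed
        if flag.item() < dist.get_world_size():
            break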

Thanks for the tips! It took me a while to understand and implement async_op.

I would like to point out a problem I ran into when running my own code above.

I changed my code to

is_next = torch.Tensor([self.data_loader.peek()]).cuda(self.gpu)
col_handle = dist.all_reduce(is_next, op=dist.ReduceOp.BAND, async_op=True)
...
col_handle.wait()
if not is_next: break

and tried it in a single-process single-GPU (SPSG) setup with 2 processes. The final value of is_next was [2] rather than [True] or [1]. It seems that dist.ReduceOp.BAND summed the input tensors rather than performing a regular AND. Therefore I changed the first line to:

is_next = torch.Tensor([self.data_loader.peek()]).bool().cuda(self.gpu)

The error message said that all_reduce does not support this tensor type for now. To achieve my goal, I used dist.ReduceOp.MIN instead (since every flag is 0 or 1, MIN behaves like a logical AND). Here’s my final code, which runs smoothly without imbalanced for-loop counts blocking the synchronization process.

for inputs, labels in self.data_loader:
    is_next = torch.Tensor([self.data_loader.peek()]).cuda(self.gpu)
    col_handle = dist.all_reduce(is_next, op=dist.ReduceOp.MIN, async_op=True)
    # forward and backward and step and stuff                    
    col_handle.wait()
    if not is_next: break