Need to propagate an EARLY_STOP flag from GPU#0 to all other GPUs in a DDP setting

I am training my DDP model with 2 GPUs on a single node. I use a DistributedSampler for the training loop. Validation is done on GPU#0 only. Before training, I initialize an EARLY_STOP flag to zero on all processes. GPU#0 sets its EARLY_STOP value when the early-stopping criterion is met. I then tried both broadcast and all_reduce (outside the if gpu_id == 0 block) and I keep getting a RuntimeError on the line:

torch.distributed.broadcast(EARLY_STOP, src=0)

The errors are:

RuntimeError: Detected mismatch between collectives on ranks. Rank 1 is running collective: CollectiveFingerPrint(OpType=BROADCAST, TensorShape=[], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(OpType=ALLGATHER_COALESCED).

and

RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(OpType=BROADCAST, TensorShape=[330], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 1 is running collective: CollectiveFingerPrint(OpType=_ALLGATHER_BASE).

Why am I getting these errors? Can someone post a small example of code that works with this setting? I am using torchrun.
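
Conceptually, this is the pattern I am trying to get working (just a sketch; train_one_epoch and early_stop_criterion_met are placeholders, and the process group is assumed to be already initialized by torchrun + init_process_group):

EARLY_STOP = torch.tensor(0.0, device=gpu_id)         # same initial value on every rank

for epoch in range(max_epochs):
    train_one_epoch()                                  # executed by all ranks

    if gpu_id == 0 and early_stop_criterion_met():
        EARLY_STOP.fill_(1)                            # only rank 0 decides

    torch.distributed.broadcast(EARLY_STOP, src=0)     # every rank must call this
    if EARLY_STOP.item() > 0.5:
        break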

Below is my code:

    # Early-stop flag kept as a tensor on this rank's GPU so it can be used in collectives
    EARLY_STOP = torch.tensor(0).float().to(gpu_id)
    if gpu_id == 0:
        early_stop_counter = 0
        rank0_train_losses = []
    
    for epoch in range(best_epoch + 1, max_epochs, 1):
        if gpu_id == 0:
            rank0_train_loss = 0.
        
        train_dl.sampler.set_epoch(epoch)     # To shuffle DistributedSampler data
        for batch_idx, batch in enumerate(train_dl):
            batch = batch.to(gpu_id)
            optimizer.zero_grad()
            out = model(batch)
            loss = torch.nn.MSELoss()(out, batch)
            loss.backward()
            optimizer.step()
            
            if gpu_id == 0:
                rank0_train_loss += loss.item()
        
        
        if gpu_id == 0:
            rank0_train_loss /= (batch_idx + 1)
            rank0_train_losses.append(rank0_train_loss)
            if epoch % train_verbose == 0:
                print(f'rank0_train_loss = {rank0_train_loss} @ epoch#{epoch}')
        
        
        # 8) Validation
        if epoch % val_verbose != 0:
            continue
        
        
        if gpu_id == 0:
            epoch_val_loss = 0.
            model.eval()
            with torch.no_grad():
                for batch in val_dl:
                    batch = batch.to(gpu_id)
                    out = model(batch)
                    loss = torch.nn.MSELoss(reduction='sum')(out, batch)
                    epoch_val_loss += loss.item()

            model.train()

            epoch_val_loss /= len(val_ds)
            print(f'epoch_val_loss = {epoch_val_loss} @ epoch#{epoch} .... min_val_loss = {min_val_loss} @ epoch#{best_epoch}')
            
            
            if epoch_val_loss < min_val_loss:
                early_stop_counter = 0
                best_epoch = epoch
                min_val_loss = epoch_val_loss
                snapshot = {
                    'MODEL_STATE': deepcopy(model.module.state_dict()),
                    'BEST_EPOCH': best_epoch,
                    'MIN_VAL_LOSS': min_val_loss
                }
                torch.save(snapshot, snapshot_path)
            else:
                early_stop_counter += 1
                if early_stop_counter == early_stop_patience:
                    #EARLY_STOP += 1
                    EARLY_STOP.fill_(1)
        
        
        # Broadcast the value of EARLY_STOP from GPU 0 to all other GPUs
        torch.distributed.broadcast(EARLY_STOP, src=0)
        
        # Alternative: all_reduce modifies EARLY_STOP in place (it does not return the tensor)
        #torch.distributed.all_reduce(EARLY_STOP, op=torch.distributed.ReduceOp.MAX)
        
        if EARLY_STOP.item() > 0.5:
            break

For those who are interested, the problem seems to have been solved simply by calling model.module(batch) instead of model(batch) in the validation loop. I think this is because the validation forward pass runs on a single GPU only, so it should go through the underlying torch.nn.Module rather than the DDP-wrapped model, whose forward can trigger its own collectives that the other ranks never join (which would explain the ALLGATHER entries on rank 0 in the errors above). I am not sure of this explanation though :slight_smile:
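
In case it helps anyone else, below is a minimal self-contained sketch of the full pattern that works for me with torchrun (a toy Linear model and random tensors stand in for my real data; patience, the layer sizes, the epoch count, etc. are arbitrary placeholders):

# torchrun --nproc_per_node=2 early_stop_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler


def main():
    dist.init_process_group(backend='nccl')
    gpu_id = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(gpu_id)

    # Toy reconstruction task: a Linear layer trying to reproduce its input
    train_ds = TensorDataset(torch.randn(256, 16))
    val_ds = TensorDataset(torch.randn(64, 16))
    train_dl = DataLoader(train_ds, batch_size=32, sampler=DistributedSampler(train_ds))
    val_dl = DataLoader(val_ds, batch_size=32)

    model = DDP(torch.nn.Linear(16, 16).to(gpu_id), device_ids=[gpu_id])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    EARLY_STOP = torch.tensor(0.0, device=gpu_id)        # flag exists on every rank
    min_val_loss, early_stop_counter, patience = float('inf'), 0, 3

    for epoch in range(100):
        train_dl.sampler.set_epoch(epoch)
        for (batch,) in train_dl:
            batch = batch.to(gpu_id)
            optimizer.zero_grad()
            loss = torch.nn.MSELoss()(model(batch), batch)   # DDP forward: all ranks
            loss.backward()
            optimizer.step()

        if gpu_id == 0:                                      # validation on rank 0 only
            model.eval()
            with torch.no_grad():
                epoch_val_loss = 0.0
                for (batch,) in val_dl:
                    batch = batch.to(gpu_id)
                    out = model.module(batch)                # plain module, not the DDP wrapper
                    epoch_val_loss += torch.nn.MSELoss(reduction='sum')(out, batch).item()
                epoch_val_loss /= len(val_ds)
            model.train()

            if epoch_val_loss < min_val_loss:
                min_val_loss, early_stop_counter = epoch_val_loss, 0
            else:
                early_stop_counter += 1
                if early_stop_counter == patience:
                    EARLY_STOP.fill_(1)

        # Every rank reaches this collective, so nothing gets mismatched
        dist.broadcast(EARLY_STOP, src=0)
        if EARLY_STOP.item() > 0.5:
            if gpu_id == 0:
                print(f'early stopping at epoch {epoch}')
            break

    dist.destroy_process_group()


if __name__ == '__main__':
    main()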