Issue with Training Loop Using DDP and AMP: Process Getting Stuck

Hi everyone,

I’m working on a project that combines Distributed Data Parallel (DDP) and Automatic Mixed Precision (AMP). The forward pass (pred = model(x)) works fine, but the training process gets stuck during the backward pass, and I can’t figure out why.

Here’s the code for my training loop:

import os

import numpy as np
import torch
import torch.distributed as dist
from sklearn.model_selection import train_test_split
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

# project-specific imports (config, config_swin, HEDataset, MViT, HE_loss, validate_he) omitted

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # bind this process to its GPU and initialize the NCCL process group
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method='env://')


def main(rank, video_path, label_path, world_size, run):
    setup(rank, world_size)

    cfg = config()
    cfg_swin = config_swin()
    
    nb_videos = len(os.listdir(video_path))
    train_indices, val_indices = train_test_split(np.arange(nb_videos), test_size=0.1, random_state=0)

    train_dataset = HEDataset(video_path, label_path, train_indices)
    val_dataset = HEDataset(video_path, label_path, val_indices)
    
    # each rank gets its own shard of the data; drop_last keeps the number of batches equal across ranks
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)


    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=False, pin_memory=False, num_workers=0, sampler=train_sampler)
    val_dataloader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, pin_memory=False, num_workers=0, sampler=val_sampler)

    model = MViT(cfg, cfg_swin).to(rank)
    model = DDP(model, device_ids=[rank])

    scaler = torch.amp.GradScaler('cuda', enabled=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, betas=(0.9, 0.999))
    
    if rank == 0:
        run.watch(model, log='all')  # watch the model from the main process only
    for epoch in tqdm(range(cfg.n_epochs), desc='Epochs', disable=rank != 0):
        # reshuffle each rank's shard of the training data every epoch
        train_dataloader.sampler.set_epoch(epoch)
        val_dataloader.sampler.set_epoch(epoch)

        model = train_he(train_dataloader, model, optimizer, scaler, epoch, cfg, run)
        model = validate_he(val_dataloader, model, scaler, epoch, cfg, run)

    dist.destroy_process_group()


def train_he(train_dataloader, model, optimizer, scaler, epoch, cfg, run):

    loss_list = []
    model.train()
    for i, batch in enumerate(tqdm(train_dataloader, desc='Training loop')):
        x, label = batch

        device = next(model.parameters()).device

        x, label = x.to(device), label.to(device)

        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            pred = model(x)
            
            # loss 
            loss = HE_loss(label, pred)

        scaler.scale(sum(loss)).backward() ### It gets stuck here
        scaler.step(optimizer)
        scaler.update()
        
        loss_list.append(sum(loss).item())
    run.log({"Train Loss" : np.mean(loss_list)}, step=epoch)
    print(f"Epoch {epoch+1} - Train Loss: {np.mean(loss_list):.4f}")
    return model

I’d really appreciate it if someone could review the code and let me know if there’s anything wrong with it or if I’m missing something important.

Some additional details about my setup:

  • Environment: Python 3.12.8, PyTorch 2.5.1, CUDA 12.4

Just to double-check: does it run fine if you remove AMP and keep everything else the same?
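
In other words, something like this stripped-down FP32 step (just a sketch reusing the names from your loop, with no autocast and no GradScaler):

optimizer.zero_grad()
pred = model(x)              # same forward as before
loss = HE_loss(label, pred)  # assumed to still return several loss terms
sum(loss).backward()         # plain FP32 backward
optimizer.step()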

In the tutorial (Automatic Mixed Precision examples — PyTorch 2.5 documentation) there isn’t a sum(loss). If your total loss is made up of several losses, you should use loss.sum().
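
Roughly this pattern, i.e. combine the individual terms into one scalar before scaling (the term names below are just placeholders for whatever HE_loss returns):

total_loss = torch.stack([loss_1, loss_2, loss_3]).sum()  # placeholder names; one scalar total
scaler.scale(total_loss).backward()
scaler.step(optimizer)
scaler.update()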

Yes, I checked it before and after adding AMP, and it works fine.

I made the change as you suggested, but it still isn’t working.
Here is my loss:

loss = torch.stack([BLEFCO1_loss, BLEFCO2_loss, BLEFCO3_loss, decision_loss, born_loss])
scaler.scale(loss.sum()).backward()