Calculating Training Loss in DDP

I have the following basic average loss calculation in my training loop:

def train_one_epoch(model, criterion, optimizer, train_loader):
    model.train()
    running_loss = 0
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.to(device='cuda')
        out = model(data)
        loss = criterion(out, data.labels)
        loss.backward()
        optimizer.step()
        # Do I need to reduce/gather from all distributed processes here?
        running_loss += loss.detach()
    return running_loss / (i + 1)

My question is about the running_loss reported under DDP versus single-GPU training. Do I need to apply an all_gather or all_reduce operation when running this training loop under DDP?

Hi @Alec-Stashevsky, DDP keeps the loss local and averages the gradients. How are you using this running loss? If it is just for reporting the loss on each node, then this seems okay; each rank will report a different running loss.
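
A minimal sketch illustrating this behavior (the model and data here are toy placeholders; it assumes a launch via torchrun with one GPU per process):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")  # env vars supplied by torchrun
rank = dist.get_rank()
device = torch.device("cuda", rank)

model = DDP(torch.nn.Linear(10, 1).to(device), device_ids=[rank])
criterion = torch.nn.MSELoss()

x = torch.randn(32, 10, device=device)  # each rank draws different data
y = torch.randn(32, 1, device=device)

loss = criterion(model(x), y)  # purely local value; no communication yet
loss.backward()                # DDP averages gradients across ranks here
print(f"rank {rank}: loss = {loss.item():.4f}")  # generally differs per rank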

@H-Huang thanks for the reply! I am actually using it to report the training loss across all ranks at the end of each training epoch, and I only log on the rank 0 process. In that case it sounds like I need some type of reduce or gather operation to sync across all devices? Currently I am using the default loss reduction (reduction='mean'), which gives a single scalar rather than a per-sample tensor, and I am unsure how that interacts with syncing as well.
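
For reference, a quick check with a toy criterion shows that reduction='mean' does return a tensor, just a zero-dimensional one, so collective ops can still operate on it:

import torch

criterion = torch.nn.CrossEntropyLoss()  # reduction='mean' is the default
out = torch.randn(8, 5)                  # toy logits: 8 samples, 5 classes
labels = torch.randint(0, 5, (8,))

loss = criterion(out, labels)
print(loss.shape)   # torch.Size([]) -- a zero-dimensional tensor
print(loss.item())  # .item() extracts the Python float for logging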

@H-Huang any update? @BraveDistribution

Hello, I would also be interested in any updates. In particular, in my case I am using NLLLoss.

Did you ever work out how to do this?

Do the DDP hooks mentioned here help: DDP Communication Hooks — PyTorch 2.5 documentation?

Thank you, Hugo. Do you know where I could find a code example that uses DDP communication hooks to compute the local-equivalent loss for training (or validation)?

I think the loss you get in DDP (before calling backward) is just the local loss?

I am looking for the local-equivalent loss (the loss you would get without DDP, i.e. an overall loss across all processes) rather than any individual rank's local loss.

You can just use an all_reduce or all_gather to manually aggregate the loss however you prefer. This doesn't have to be coupled with DDP.
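
A minimal sketch of the all_reduce route, assuming the default process group is initialized and running_loss is the CUDA tensor accumulated in the loop above (the function name is hypothetical):

import torch.distributed as dist

def epoch_loss_across_ranks(running_loss, num_batches):
    # Average this rank's accumulated batch losses first.
    avg_loss = running_loss / num_batches
    # Every rank must execute this collective; it sums the per-rank
    # averages in place across the default process group.
    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
    # Divide by the number of ranks to get the mean of the rank averages.
    return avg_loss / dist.get_world_size()

All ranks call this at the end of the epoch; rank 0 can then log the returned value. Note this gives the mean of per-rank averages, which equals the true global average only when every rank processes the same number of samples.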

Yes, thanks! I actually managed to compute the local-equivalent losses.
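
The final code is not shown in the thread, but one common recipe for the exact single-process-equivalent loss with reduction='mean' and possibly uneven per-rank sample counts is to all-reduce a (loss sum, sample count) pair. A sketch, with hypothetical names:

import torch
import torch.distributed as dist

def local_equivalent_loss(loss_sum, num_samples, device):
    # loss_sum: this rank's sum of (batch mean loss * batch size),
    # i.e. the sum of per-sample losses seen by this rank.
    totals = torch.tensor([loss_sum, float(num_samples)],
                          dtype=torch.float64, device=device)
    # One collective sums both quantities across all ranks.
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    # Global mean over every sample, matching a single-process run.
    return (totals[0] / totals[1]).item()

Inside the training loop this pairs with loss_sum += loss.item() * batch_size and num_samples += batch_size on each rank.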