How to synchronize loss calculation in DDP?

Within a mini-batch, different types of loss are computed for different samples, depending on some conditions. Because the branch taken can differ between ranks, the processes fall out of sync and the GPUs hang. I tried to fix it with dist.barrier(), but that didn't work.
How can I synchronize properly?

total_loss = 0
for i, idx in enumerate(idx_in_batch):
    # the branch taken can differ per sample (and therefore per rank)
    if condition_A:
        loss_A = ...
        total_loss = total_loss + loss_A
    else:
        loss_B = ...
        total_loss = total_loss + loss_B

total_loss /= imgs.size(0)  # average over the local batch
print(total_loss)
# dist.barrier()  # tried adding a barrier here; it didn't help
losses.update(total_loss.item(), imgs.size(0))

# compute gradient and do SGD step
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
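
For reference, is an explicit all-reduce on the scalar loss the right way to make every rank see the same value? This is a rough sketch of what I mean (assuming the default process group is already initialized; loss_for_logging is just a name I made up for the detached copy):

import torch.distributed as dist

# Average the per-rank mean loss across all ranks, for logging only.
# DDP already averages gradients during backward(), so this copy is
# detached and does not touch the autograd graph.
loss_for_logging = total_loss.detach().clone()
dist.all_reduce(loss_for_logging, op=dist.ReduceOp.SUM)
loss_for_logging /= dist.get_world_size()
losses.update(loss_for_logging.item(), imgs.size(0))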