Within a mini-batch, different losses are computed depending on a per-sample condition. Because of this, the ranks fall out of sync and the GPUs hang.
I tried to solve this with dist.barrier(), but it didn't work. How can I synchronize the ranks properly?
```python
total_loss = 0
for i, idx in enumerate(idx_in_batch):
    if condition_A:
        loss_A = ...
        total_loss = total_loss + loss_A
    else:
        loss_B = ...
        total_loss = total_loss + loss_B

total_loss /= imgs.size(0)
print(total_loss)
# dist.barrier()
losses.update(total_loss.item(), imgs.size(0))

# compute gradient and do SGD step
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
```
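For context, one common way to keep every rank executing the same graph is to compute both losses for every sample and select between them with a mask, instead of branching per sample. This is a minimal single-process sketch of that idea, not the original code: `preds`, `targets`, `cond_a`, and the two example loss formulas are all hypothetical stand-ins for the real model outputs and `condition_A`.

```python
import torch

# Hypothetical stand-ins for the real model outputs and targets.
torch.manual_seed(0)
preds = torch.randn(4, requires_grad=True)
targets = torch.zeros(4)

# Hypothetical per-sample boolean version of condition_A.
cond_a = torch.tensor([True, False, True, False])

# Compute BOTH branch losses for every sample, then select with the mask.
# Every rank runs the same ops on the same parameters, so DDP's gradient
# all-reduce stays aligned and no rank is left waiting.
loss_a = (preds - targets) ** 2       # example: MSE-style branch
loss_b = (preds - targets).abs()      # example: L1-style branch
per_sample = torch.where(cond_a, loss_a, loss_b)

total_loss = per_sample.mean()
total_loss.backward()
print(total_loss.item())
```

Whether this applies depends on whether `condition_A` can be expressed as a per-sample tensor; if it can, the control flow disappears and synchronization is no longer an issue.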