Within a mini-batch, I compute different types of loss for different samples depending on a condition. Because of this, the ranks fall out of sync and the GPUs hang. I tried to fix it with dist.barrier(), but that didn't help. How can I synchronize properly?
total_loss = 0
for i, idx in enumerate(idx_in_batch):
    if condition_A:
        loss_A = ...
        total_loss = total_loss + loss_A
    else:
        loss_B = ...
        total_loss = total_loss + loss_B

total_loss /= imgs.size(0)
print(total_loss)
# dist.barrier()
losses.update(total_loss.item(), imgs.size(0))

# compute gradient and do SGD step
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
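A likely cause of the hang: `dist.barrier()` only synchronizes control flow, not the gradient all-reduce that DDP runs inside `backward()`. If the condition sends different ranks down different branches, some parameters receive no gradient on some ranks, and the collective communication deadlocks. One workaround (a minimal single-process sketch; `TwoHead`, `head_a`, `head_b`, and the random `condition_a` are stand-ins for your model and condition) is to compute both losses for every sample and select with a mask, so every parameter appears in every backward pass on every rank:

```python
import torch
import torch.nn as nn

# Stand-in model with two loss heads. Under DDP, a hang can occur when a
# rank's batch never exercises one head: that head's parameters get no
# gradient there, and the all-reduce for them never fires.
class TwoHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head_a = nn.Linear(8, 1)
        self.head_b = nn.Linear(8, 1)

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return self.head_a(h), self.head_b(h)

model = TwoHead()
imgs = torch.randn(6, 4)                    # stand-in mini-batch
condition_a = torch.rand(6) > 0.5           # stand-in per-sample condition

out_a, out_b = model(imgs)
loss_a = out_a.squeeze(1).pow(2)            # placeholder per-sample losses
loss_b = out_b.squeeze(1).abs()

# Mask instead of branching: both heads stay in the autograd graph, so
# every parameter gets a (possibly zero) gradient on every rank.
mask = condition_a.float()
total_loss = (mask * loss_a + (1 - mask) * loss_b).sum() / imgs.size(0)
total_loss.backward()

# Every parameter now has a populated .grad, keeping DDP's all-reduce
# in lockstep across ranks.
assert all(p.grad is not None for p in model.parameters())
```

Alternatively, if computing both branches is too expensive, wrapping the model with `DistributedDataParallel(model, find_unused_parameters=True)` tells DDP to detect parameters that did not participate in the forward pass and mark them ready, which also avoids the deadlock, at some extra overhead per iteration.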