How to use "break" in DistributedDataParallel training

The idea is to let rank 0 set a flag and share it with every rank through all_reduce, so all processes break out of the loop together (if only one rank breaks, the other ranks hang waiting for its collective calls):

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

# Every rank must create the flag tensor and join the all_reduce,
# otherwise the collective call would hang.
flag_tensor = torch.zeros(1).to(device)
if local_rank == 0:
    # Conditions for breaking the loop:
    flag_tensor += 1
dist.all_reduce(flag_tensor, op=ReduceOp.SUM)
if flag_tensor == 1:
    print("Training stopped")
    break
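
Below is a minimal self-contained sketch of how this flag might sit inside a full DDP training loop, assuming a torchrun-style launch (one process per GPU, LOCAL_RANK set by the launcher). The toy model, the dummy data, and the should_stop threshold are placeholders for illustration, not from the original post:

import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch.nn.parallel import DistributedDataParallel as DDP

def should_stop(loss):
    # Hypothetical early-stopping criterion; replace with your own logic.
    return loss.item() < 0.05

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    model = DDP(torch.nn.Linear(10, 1).to(device), device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(100):
        # Dummy batch; in practice this comes from a DistributedSampler-backed loader.
        inputs = torch.randn(32, 10, device=device)
        targets = torch.randn(32, 1, device=device)

        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Rank 0 decides whether to stop; all_reduce shares that decision with
        # every rank so they all break on the same epoch and no process is left
        # waiting inside a collective call.
        flag_tensor = torch.zeros(1, device=device)
        if local_rank == 0 and should_stop(loss):
            flag_tensor += 1
        dist.all_reduce(flag_tensor, op=ReduceOp.SUM)
        if flag_tensor == 1:
            if local_rank == 0:
                print("Training stopped at epoch", epoch)
            break

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Launched with, for example, torchrun --nproc_per_node=2 script.py, every rank evaluates the same flag value after the all_reduce, so the loop exits on the same epoch everywhere.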
