I have the following lines in my forward method:
all_image_features1 = all_gather(image_features1)
all_image_features2 = all_gather(image_features2)
print(f"FORWARD (AFTER SYNC): {time.time()}, rank={self.global_rank}")
where all_gather is a small wrapper around torch's all_gather.
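For context, the wrapper is essentially the following. This is a hypothetical minimal sketch, not my exact code (my real wrapper may differ, e.g. in how it handles gradients); the single-process gloo group at the bottom is only there so the snippet runs standalone:

```python
import tempfile

import torch
import torch.distributed as dist


def all_gather(tensor):
    """Gather `tensor` from every rank and concatenate along dim 0.

    Hypothetical sketch of the wrapper described above.
    """
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)  # collective: every rank contributes its shard
    return torch.cat(gathered, dim=0)


# Single-process gloo group, only to exercise the wrapper locally.
dist.init_process_group(
    "gloo", init_method=f"file://{tempfile.mktemp()}", rank=0, world_size=1
)
x = torch.arange(4.0).reshape(2, 2)
out = all_gather(x)  # with world_size=1 this is just a copy of x
dist.destroy_process_group()
```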
I have noticed that the timestamp printed after the sync is not the same across processes (I am using DistributedDataParallel); it can diverge by as much as 8 seconds!
Presumably, all_gather should force all the training processes to synchronize. Why is this not happening?