I tried to do that, but it doesn't work. You said that the other processes would see signal.item() < world_size, but the first process shows the signal to be 0, so it never blocks on all_reduce. Can you tell me where I'm going wrong?
if self.id == 0:
    ep = 10
    sl = 0.1
elif self.id == 1:
    ep = 15
    sl = 0.2

for i in range(0, ep):
    # if i == 10:
    #     break
    signal = torch.tensor([1]).to(self.device)
    work = dist.all_reduce(signal, async_op=True)

    # self._log.info(self.id, "image.. " + str(i))
    image = torch.zeros((1, 2, 3, 256, 384)).cuda(self.id)
    input = {"rgb": image}
    output = self.model(input)
    loss = torch.sum(output["output"][-1])

    work.wait()
    self._log.info(self.id, "SIG: " + str(signal.item()))
    if signal.item() < self.cfg.mp.workers:
        self._log.info(self.id, "EXIT: " + str(signal.item()))
        break

    self.optimizer.zero_grad()
    # self._log.info(self.id, "backward.. " + str(i))
    loss.backward()
    self.optimizer.step()
    # print("step: " + str(self.id) + " " + str(i))
    self._log.info(self.id, "done " + str(i))
    time.sleep(sl)

self._log.info(self.id, "reduce.. ")
signal = torch.tensor([0]).to(self.device)
self._log.info(self.id, "FINAL SIGNAL: " + str(signal.item()))
if signal.item() >= self.cfg.mp.workers:
    dist.all_reduce(signal)  # NOT BLOCKING HERE

self._log.info(self.id, "barrier.. ")
dist.barrier()
Yep, because BatchNorm would trigger DDP communication in the forward pass as well. In that case, you need to move the signal check before the forward pass, but it will be slower. The following code should work.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # the default env:// init method needs these to be set
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.BatchNorm1d(10),
        nn.Linear(10, 10),
    ).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for _ in range((rank + 1) * 10):
        # check the signal before any DDP communication (including forward)
        signal = torch.tensor([1])
        dist.all_reduce(signal)
        if signal.item() < world_size:
            break
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))
    dist.barrier()
    print(f"{rank} done")


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    main()
If you know the number of inputs before entering the for loop, you can use an allreduce to get the min of that number across all processes. Then let all processes just loop that min number of times.
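A minimal sketch of that first approach, assuming the process group is already initialized; dataloader, train_step, and device are placeholder names, not anything from the thread:

import torch
import torch.distributed as dist

def run_equal_steps(dataloader, train_step, device):
    # Every rank contributes its local number of batches and keeps only the
    # smallest one, so all ranks agree on how many iterations to run.
    n_local = torch.tensor([len(dataloader)], device=device)
    dist.all_reduce(n_local, op=dist.ReduceOp.MIN)
    n_steps = n_local.item()

    it = iter(dataloader)
    for _ in range(n_steps):
        batch = next(it)   # safe: every rank has at least n_steps batches
        train_step(batch)  # forward/backward/step; DDP collectives stay matched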
If this number cannot be acquired or is too expensive to acquire, you can first retrieve one input/batch/sample from the source/dataloader before entering the loop, say it's called last_input. Then let each iteration (a sketch follows the list):
1. retrieve a new input (call it curr_input) from the dataloader,
2. launch an allreduce to check whether all processes have a curr_input,
3. run forward-backward-opt on last_input,
4. wait on the allreduce for curr_input,
5. if all processes have a curr_input, assign curr_input to last_input and proceed; otherwise break.
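A rough sketch of that prefetch pattern, again with placeholder names (loader, train_step, device) and assuming an initialized process group:

import torch
import torch.distributed as dist

def run_with_prefetch(loader, train_step, device):
    it = iter(loader)
    last_input = next(it, None)  # fetch the first batch before entering the loop
    while last_input is not None:
        # Try to prefetch the next batch and tell the other ranks whether we got one.
        curr_input = next(it, None)
        signal = torch.tensor([1 if curr_input is not None else 0], device=device)
        work = dist.all_reduce(signal, async_op=True)

        # Train on the batch that every rank is known to have.
        train_step(last_input)

        # Check whether every rank also prefetched a batch for the next iteration.
        work.wait()
        if signal.item() < dist.get_world_size():
            break  # at least one rank ran out of data
        last_input = curr_input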
It doesn't seem to work on my Ampere-series GPUs anymore with PyTorch 1.11 and CUDA 11.5. This chunk of code blocks on signal.item() when I run with 2 or more GPUs:
for train_datum in self.train_dataset.iterate():
    if early_stop:
        break

    # If any dataloader in any process stops, we get a signal
    if self.cfg.distributed.enabled:
        signal = torch.tensor([1]).to(self.device)
        torch.distributed.all_reduce(signal)
        if signal.item() < self.cfg.distributed.workers:  # BLOCKS ON SIGNAL.ITEM()
            self.logger.info(self.id, "Termination Signal")
            early_stop = True
            continue