Can I check, instead of using this approach:
Can I create a shared-memory torch tensor? As soon as one process finishes its batch, I can write to the tensor and have the other processes read it so they can break their loops.
Sure, that will also work. You can use a torch.multiprocessing.Queue to do that, but this solution cannot scale across multiple machines.
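For illustration, a minimal sketch of the shared-tensor idea on a single machine (worker, the loop bounds, and the training step are hypothetical placeholders, not code from this thread):

import torch
import torch.multiprocessing as mp

def worker(rank, stop_flag):
    # stop_flag lives in shared memory, so every process sees writes to it
    for i in range(10 * (rank + 1)):
        if stop_flag.item() == 1:  # some other process already finished
            break
        # ... run one training step here ...
    stop_flag.fill_(1)  # tell the other processes to stop

if __name__ == "__main__":
    flag = torch.zeros(1).share_memory_()  # shared across spawned processes
    mp.spawn(worker, args=(flag,), nprocs=2, join=True)

Note that if the loop also issues DDP collectives, breaking at different times can still desynchronize them, which is what the all_reduce discussion below deals with.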
Hmm… I tried this method and it works. But now the second process hangs at dist.barrier().
Final question: it's not clear why you have another all_reduce at the end.
That's to indicate that one process has exited the loop; otherwise the sum would never be smaller than world_size. It needs to be guarded by a flag, because only the process that exits early needs it. Something like:
for batch in get_batch():
    ...

if signal.item() >= world_size:
    all_reduce(torch.tensor([0]))
I tried implementing it as you suggested, but my second process gets stuck on the barrier:
if self.id == 0:
    ep = 10
    sl = 0.1
elif self.id == 1:
    ep = 15
    sl = 0.2
for i in range(0, ep):
    # if i == 10:
    #     break
    signal = torch.tensor([1]).to(self.device)
    work = dist.all_reduce(signal, async_op=True)
    self._log.info(self.id, "image.. " + str(i))
    image = torch.zeros((1, 2, 3, 256, 384)).cuda(self.id)
    input = {"rgb": image}
    output = self.model(input)
    loss = torch.sum(output["output"][-1])
    work.wait()
    if signal.item() < self.cfg.mp.workers:
        break
    self.optimizer.zero_grad()
    self._log.info(self.id, "backward.. " + str(i))
    loss.backward()
    self.optimizer.step()
    print("step: " + str(self.id) + " " + str(i))
    self._log.info(self.id, "done " + str(i))
    time.sleep(sl)
self._log.info(self.id, "reduce.. ")
signal = torch.tensor([0]).to(self.device)
dist.all_reduce(signal)
self._log.info(self.id, "barrier.. ")
dist.barrier()
return
The final all_reduce should only be done by the process that exits first. How do I ensure that?
You can guard it with a flag. For the first process that exits (call it X), its signal tensor will always equal world_size. The other processes will see signal.item() < world_size, because their in-loop all_reduce joins with X's after-loop all_reduce.
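Concretely, with world_size = 2 and rank 0 running 10 batches while rank 1 runs 15 (the setup in the snippet below), the matched all_reduce calls line up like this (a hand-worked trace, not program output):

iterations 1-10: rank 0 sends 1 (in-loop), rank 1 sends 1 (in-loop) -> sum 2, both proceed
iteration 11:    rank 0 sends 0 (after-loop), rank 1 sends 1 (in-loop) -> sum 1, rank 1 breaks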
I tried to do that but it doesn't work. You said that the other processes would see signal.item() < world_size, but the first process shows the signal to be 0, so it never blocks on all_reduce. Can you tell me where I'm going wrong?
if self.id == 0:
    ep = 10
    sl = 0.1
elif self.id == 1:
    ep = 15
    sl = 0.2
for i in range(0, ep):
    # if i == 10:
    #     break
    signal = torch.tensor([1]).to(self.device)
    work = dist.all_reduce(signal, async_op=True)
    # self._log.info(self.id, "image.. " + str(i))
    image = torch.zeros((1, 2, 3, 256, 384)).cuda(self.id)
    input = {"rgb": image}
    output = self.model(input)
    loss = torch.sum(output["output"][-1])
    work.wait()
    self._log.info(self.id, "SIG: " + str(signal.item()))
    if signal.item() < self.cfg.mp.workers:
        self._log.info(self.id, "EXIT: " + str(signal.item()))
        break
    self.optimizer.zero_grad()
    # self._log.info(self.id, "backward.. " + str(i))
    loss.backward()
    self.optimizer.step()
    # print("step: " + str(self.id) + " " + str(i))
    self._log.info(self.id, "done " + str(i))
    time.sleep(sl)
self._log.info(self.id, "reduce.. ")
signal = torch.tensor([0]).to(self.device)
self._log.info(self.id, "FINAL SIGNAL: " + str(signal.item()))
if signal.item() >= self.cfg.mp.workers:
    dist.all_reduce(signal)  # NOT BLOCKING HERE
self._log.info(self.id, "barrier.. ")
dist.barrier()
The self-contained code below works for me.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # rendezvous info for the default env:// init method
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # rank 0 runs 10 iterations, rank 1 runs 20, so inputs are uneven
    for _ in range((rank + 1) * 10):
        signal = torch.tensor([1])
        work = dist.all_reduce(signal, async_op=True)
        # forward pass, overlapped with the signal all_reduce
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        work.wait()
        if signal.item() < world_size:
            break  # some other process has already exited its loop
        # backward pass
        optimizer.zero_grad()
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    # only the first process to exit still sees signal == world_size;
    # its extra all_reduce pairs with the in-loop all_reduce of the others
    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))
    dist.barrier()
    print(f"{rank} done")


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    main()
Thanks, I finally made it work with your help!
This doesn't work anymore if I use BatchNorm on the GPU. If I disable track_running_stats, then it works. It hangs on the line:

if signal.item() >= world_size:
    dist.all_reduce(torch.tensor([0]))
Yep, because BatchNorm triggers DDP communication in the forward pass as well: by default, DDP broadcasts module buffers (here, the running stats) at the start of every forward, and that broadcast is itself a collective op. In that case, you need to move the signal check before the forward pass, but it will be slower. The following code should work.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # rendezvous info for the default env:// init method
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.BatchNorm1d(10),
        nn.Linear(10, 10),
    ).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for _ in range((rank + 1) * 10):
        # the signal all_reduce must now finish BEFORE the forward pass,
        # because DDP's buffer broadcast in forward is itself a collective
        signal = torch.tensor([1])
        dist.all_reduce(signal)
        if signal.item() < world_size:
            break
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        optimizer.zero_grad()
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))
    dist.barrier()
    print(f"{rank} done")


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    main()
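If unsynchronized running stats are acceptable (the observation above that disabling track_running_stats avoids the hang points the same way), one hedged alternative is to turn off DDP's buffer broadcast so the forward pass issues no communication. A minimal sketch, assuming the same model as above; this is a workaround, not the fix proposed in this thread:

# broadcast_buffers=False stops DDP from broadcasting BatchNorm's running
# stats on every forward, so the forward pass no longer issues a collective
# and the async signal pattern from the earlier example can be used again.
# Trade-off: each rank keeps its own, unsynchronized running stats.
ddp_model = DDP(model, device_ids=[rank], broadcast_buffers=False)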
It's much slower; GPU utilization is at 50 percent now. Is there any alternative solution to end the processes?
OK, I stand corrected. The slowdown comes from SyncBatchNorm.
We are working on a more elegant and efficient solution; see the tracking issue: [RFC] Join-based API to support uneven inputs in DDP · Issue #38174 · pytorch/pytorch.
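For reference, that RFC later landed as a join() context manager on DistributedDataParallel (available since PyTorch 1.7). A minimal sketch, reusing ddp_model, loss_fn, and optimizer from the examples above, with loader as a hypothetical per-rank dataloader:

# Inside join(), ranks that exhaust their loader early "shadow" the
# collective calls of the ranks still training, so uneven input counts
# no longer cause a hang.
with ddp_model.join():
    for inputs, labels in loader:  # loaders may have different lengths
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss_fn(outputs, labels).backward()
        optimizer.step()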
To unblock, two options:

1. If the number of batches is known up front, all_reduce that number to get the min across all processes, then let every process loop exactly that min number of times.
2. Otherwise, stay one batch ahead: keep the current batch in last_input. In each iteration, load the next batch (curr_input) from the dataloader and run an all_reduce on whether curr_input exists. If every process could fetch a curr_input, proceed and assign curr_input to last_input; otherwise break. (A sketch follows below.)

Just want to confirm: is it SyncBatchNorm or BatchNorm?
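A minimal sketch of option 2, assuming a generic loader iterator and a train_step function (both hypothetical names, not from the thread):

import torch
import torch.distributed as dist

def train_until_any_exhausted(loader, train_step, device):
    # Stay one batch behind: a batch is only trained once all processes
    # have confirmed, via all_reduce, that they could fetch it.
    it = iter(loader)
    last_input = next(it)  # prefetch the first batch (assumed non-empty)
    while True:
        try:
            curr_input = next(it)
            flag = torch.tensor([1], device=device)
        except StopIteration:
            curr_input = None
            flag = torch.tensor([0], device=device)
        # MIN over all processes: drops to 0 as soon as any process runs dry
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
        train_step(last_input)  # every process still executes this together
        if flag.item() == 0:
            break  # some process is out of data: all stop in the same iteration
        last_input = curr_input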
Okay, it's slower when using SyncBatchNorm, but it is also slightly slower when using no BatchNorm at all.
But maybe, for some reason, we don't have any input at all for one GPU device, or we skip all of the input for one of the devices. How can we use last_input for the all_reduce operation in that case?
It doesn't seem to work on my Ampere-series GPU anymore with PyTorch 1.11 and CUDA 11.5. This chunk of code blocks on signal.item() when I run with 2 or more GPUs:
for train_datum in self.train_dataset.iterate():
    if early_stop:
        break
    # If any dataloader in any process stops, we get a signal
    if self.cfg.distributed.enabled:
        signal = torch.tensor([1]).to(self.device)
        torch.distributed.all_reduce(signal)
        if signal.item() < self.cfg.distributed.workers:  # BLOCKS ON signal.item()
            self.logger.info(self.id, "Termination Signal")
            early_stop = True
            continue