wconstab
(Will Constable)
December 12, 2023, 6:26pm
It can be tricky to use the Python debugger from a multi-rank setup. The first thing you'll notice if you try is that pdb may crash your program when used from inside an mpirun or torchrun launcher. Fortunately, this is fixable, and you can use pdb almost as usual.
There is a catch: it's not easy to attach the debugger on every rank, but it is pretty easy to attach it to just one particular rank (and let all the other ranks pause).
This PR from @ezyang (pytorch/pytorch#113775) adds a new helper called torch.distributed.breakpoint. It can be used more or less like Python's built-in breakpoint(), except that you're supposed to call it on every rank, always passing the same rank argument, so that across all ranks exactly one rank listens for the debugger input while the others pause.
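For example, here is a minimal sketch (mine, not from the PR) of calling it from your own script launched with torchrun; the script name `debug_example.py` and the toy all_reduce are just placeholders:

```python
# Minimal sketch; assumes a 2-process launch, e.g.:
#   torchrun --nproc_per_node=2 debug_example.py
import torch
import torch.distributed as dist

def main():
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT for us
    dist.init_process_group(backend="gloo")  # or "nccl" on GPUs
    rank = dist.get_rank()

    x = torch.ones(4) * rank

    # Every rank calls this with the same `rank` argument; only that
    # rank gets the (Pdb) prompt, the others wait until it continues.
    dist.breakpoint(rank=0)

    dist.all_reduce(x)  # inspect `x` in the debugger before/after this
    print(f"rank {rank}: {x}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```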
From the PR description (opened 15 Nov 2023):
I tested that it works by patching:
```diff
diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 96b3a82bdfa..dea9bac9302 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -18,6 +18,7 @@ from torch._dynamo import config
from torch._dynamo.utils import same
from torch._dynamo.testing import collect_results
from torch.utils._triton import has_triton
+import torch.distributed as dist
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy, lambda_auto_wrap_policy
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -398,6 +399,7 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_activation_checkpointing(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+ dist.breakpoint()
model, inputs = get_toy_model_for_activation_checkpointing(f"cuda:{self.rank}")
is_inner = lambda module: isinstance(module, ToyInnerModel) # noqa: E731
wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
```
and then running `python test/distributed/test_dynamo_distributed.py -k test_fsdp_activation_checkpointing`.
It prints:
```
ATTENTION!!!
Type 'up' to get to the frame that called dist.breakpoint(rank=0)
> /data/users/ezyang/c/pytorch/torch/distributed/__init__.py(71)breakpoint()
-> barrier()
(Pdb) up
> /data/users/ezyang/c/pytorch/test/distributed/test_dynamo_distributed.py(402)test_fsdp_activation_checkpointing()
-> dist.breakpoint()
(Pdb) list
397
398 @skip_if_lt_x_gpu(1)
399 @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
400 def test_fsdp_activation_checkpointing(self):
401 with _dynamo_dist_per_rank_init(self.rank, self.world_size):
402 -> dist.breakpoint()
403 model, inputs = get_toy_model_for_activation_checkpointing(f"cuda:{self.rank}")
404 is_inner = lambda module: isinstance(module, ToyInnerModel) # noqa: E731
405 wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
406 model = apply_fsdp_with_checkpointing(model, wrap_policy, is_inner)
407 correct_outputs = model(inputs)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>