The Docs on the DDP class say:
Forward and backward hooks defined on module and its submodules
won’t be invoked anymore, unless the hooks are initialized in the forward() method.
However, they are working for me. Did I misinterpret the warning? Does it mean something else?
Is there a reason I should be careful when using hooks with DDP? In my original code I use hooks to modify the output of some layers (a rough sketch of that pattern is at the end of this post, after the output). Here is a simplified minimal reproducible example:
import torch.nn as nn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))

    def forward(self, x):
        return self.model(x)


def hook_fn(mod, inp, out):
    # Fires on every forward pass of the hooked module
    if dist.is_initialized():
        print(f"The forward hook is not removed in rank : {dist.get_rank()}")


def run_model(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=2)
    model = SimpleCNN().to(rank)
    # Register the hook on both Conv2d layers before wrapping with DDP
    for mod in model.modules():
        if 'Conv' in mod.__class__.__name__:
            mod.register_forward_hook(hook_fn)
    ddp_mod = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    for _ in range(2):
        ddp_mod(torch.randn(1, 1, 10, 10))
    dist.destroy_process_group()


def run():
    mp.spawn(run_model, nprocs=2, join=True)


if __name__ == '__main__':
    run()
And the corresponding output:
The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 1
The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 1
The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 1
The forward hook is not removed in rank : 1
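For context, this is roughly the pattern I use in my original code to modify a layer's output. It is only an illustrative sketch: the single Conv2d layer and the scaling factor are placeholders, not the actual model.

import torch
import torch.nn as nn

def scale_output_hook(mod, inp, out):
    # If a forward hook returns a value, that value replaces the module's output
    return out * 2.0  # illustrative scaling factor, not from the real code

layer = nn.Conv2d(1, 1, 1)
handle = layer.register_forward_hook(scale_output_hook)
y = layer(torch.randn(1, 1, 10, 10))  # y is the hook-modified output
handle.remove()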