Distributed Data Parallel doesn't remove hooks

The docs for the DDP class say:

Forward and backward hooks defined on module and its submodules won’t be invoked anymore, unless the hooks are initialized in the forward() method.

However, the hooks are working for me. Did I misinterpret the warning, or does it mean something else?
Is there a reason I should be careful when using hooks with DDP? In my original code I use a forward hook to modify the output of some layers, roughly like the sketch below.
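
For context, a minimal sketch of that kind of output-modifying hook (the doubling is a made-up placeholder for whatever transformation the real code applies):

import torch
import torch.nn as nn


def scale_output_hook(mod, inp, out):
    # returning a non-None value from a forward hook replaces the module's output
    return out * 2.0  # hypothetical modification


layer = nn.Conv2d(1, 1, 1)
handle = layer.register_forward_hook(scale_output_hook)
out = layer(torch.randn(1, 1, 10, 10))  # out is the hook-scaled result
handle.remove()

Here's a simplified minimal reproducible example: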

import torch.nn as nn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1))

    def forward(self, x):
        return self.model(x)


def hook_fn(mod, inp, out):
    if dist.is_initialized():
        print(
            f"The forward hook is not removed in rank : {dist.get_rank()}")


def run_model(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=2)
    model = SimpleCNN().to(rank)
    for mod in model.modules():
        if 'Conv' in mod.__class__.__name__:
            mod.register_forward_hook(hook_fn)

    ddp_mod = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    for _ in range(2):
        ddp_mod(torch.randn(1, 1, 10, 10).to(rank))  # place the input on this rank's device
    dist.destroy_process_group()


def run():
    mp.spawn(run_model, nprocs=2, join=True)


if __name__ == '__main__':
    run()

And the corresponding output:

The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 1
The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 1
The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 0
The forward hook is not removed in rank : 1
The forward hook is not removed in rank : 1

@ptrblck could you help please? 🙂

I think, based on this comment, the issue would only arise if you were using the single-process, multiple-devices approach (not recommended); hooks should work fine for the single-process, single-device approach (your approach and the recommended one).
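
To make the distinction concrete, inside run_model above the two setups would differ only in the device_ids argument (a sketch, assuming two GPUs):

# single process / single device (recommended): one process per GPU;
# hooks registered on model's submodules keep firing, as your output shows
ddp_mod = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

# single process / multiple devices (not recommended): DDP replicates the
# module across the listed GPUs on each forward pass, and hooks registered
# on the original submodules are not copied to those replicas, which is
# what the docs warning refers to, as far as I understand
ddp_mod = nn.parallel.DistributedDataParallel(model, device_ids=[0, 1])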