Does DistributedDataParallel work with `torch.no_grad()` and `find_unused_parameters=False`?

Hi, in my context I forward twice to obtain two results and use one of them to guide the other (see suncet/losses.py at master · facebookresearch/suncet · GitHub):

target_supports, anchor_supports = encoder(simgs, return_before_head=True)
target_views, anchor_views = encoder(uimgs, return_before_head=True)
# anchor_supports and anchor_views are then used to compute the loss at https://github.com/facebookresearch/suncet/blob/master/src/losses.py#L65

I actually don't need the gradients of anchor_supports, so I wrapped the first forward pass in torch.no_grad() like this:

with torch.no_grad():
    target_supports, anchor_supports = encoder(simgs, return_before_head=True)
target_views, anchor_views = encoder(uimgs, return_before_head=True)

After that, DDP throws a RuntimeError asking me to set find_unused_parameters=True.

I tried setting it to True and it works, but I don't understand why that's necessary.
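
For reference, this is roughly how the flag is passed when the encoder is wrapped (a minimal self-contained sketch: a toy module and a single-process gloo group stand in for my actual setup):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# toy stand-in for the real encoder, only to show where the flag goes
encoder = torch.nn.Linear(8, 4)

# single-process group purely for illustration
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

ddp_encoder = DDP(encoder, find_unused_parameters=True)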

I also tried the following, and it worked with find_unused_parameters=False:

target_supports, anchor_supports = encoder(simgs, return_before_head=True)
target_views, anchor_views = encoder(uimgs, return_before_head=True)
anchor_supports = anchor_supports.detach()
target_supports = target_supports.detach()

Is there a way to both use torch.no_grad() to save memory and keep find_unused_parameters=False to speed things up?

There is not a way to combine torch.no_grad() with find_unused_parameters=False here, since these operations are still included in the autograd graph but only have requires_grad set to False. See the find_unused_parameters argument in the DistributedDataParallel — PyTorch master documentation.
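
As a quick illustration (a standalone toy layer, not the suncet encoder): outputs produced under torch.no_grad() come back with requires_grad=False, so the loss cannot backpropagate into them:

import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

with torch.no_grad():
    y = lin(x)

print(y.requires_grad)  # False: no gradient will flow back into this output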

When you detach anchor_supports and target_supports, those tensors are cut out of the autograd graph, so it no longer tracks them, which is why the example you gave works with find_unused_parameters=False.
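
Putting it together, here is a minimal single-process sketch of the detach-based variant you reported working (a toy linear layer and a made-up squared-error loss stand in for the suncet encoder and its loss; the gloo group exists only so DDP can be constructed):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

encoder = DDP(torch.nn.Linear(8, 4), find_unused_parameters=False)
simgs, uimgs = torch.randn(16, 8), torch.randn(16, 8)

# both forward passes are tracked, then the support outputs are cut out of the graph
anchor_supports = encoder(simgs).detach()
anchor_views = encoder(uimgs)

# every parameter still receives a gradient through the second forward pass,
# so the reducer is satisfied with find_unused_parameters=False
loss = (anchor_views - anchor_supports).pow(2).mean()
loss.backward()

dist.destroy_process_group()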
