Torch.no_grad() with DDP

Hi,

I tried to use torch.no_grad() with DDP, but it throws the following error:

This error indicates that your module has parameters that were not used in producing its output (the return value of forward). You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel

The pseudocode is as follows:


import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(256, 256)  # placeholder; the actual sub-network is not shown here

    def forward(self, x):
        with torch.no_grad():  # run these layers without building a graph
            x = self.layers(x)
        return x

class WholeFlow(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = MyModel()
        self.g = nn.Linear(256, 256)

    def forward(self, x):
        x = self.f(x)
        x = self.g(x)
        return x

model = WholeFlow()
optimizer = torch.optim.SGD(model.g.parameters(), ...)  # only g is optimized
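
The DDP wrapping itself isn't shown above. For context, here is a minimal sketch of how WholeFlow might be wrapped; the NCCL backend, torchrun launch, and device handling are assumptions, not part of the original post:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# assumed launch via torchrun, which sets LOCAL_RANK for each process
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = WholeFlow().cuda(local_rank)
ddp_model = DDP(model, device_ids=[local_rank])
# during training, the backward pass raises the unused-parameters error,
# since f's parameters never receive gradients under torch.no_grad()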

There is a similar issue here: DDP does not work well with `torch.no_grad()` in 1.2 · Issue #6087 · PyTorchLightning/pytorch-lightning · GitHub

It works with DataParallel, but not with DDP.

Any idea?

Did you try adding the suggested find_unused_parameters=True argument, and if so, did you get any other error?
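
For reference, a minimal sketch of where that argument goes (model and local_rank are assumed from a standard DDP setup):

ddp_model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[local_rank],
    find_unused_parameters=True,  # let DDP detect and skip parameters that receive no gradient
)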

Adding find_unused_parameters=True works, but is this a bug in DDP?