Hi,
I tried to use torch.no_grad() with DDP, but it throws the following error:

This error indicates that your module has parameters that were not used in producing its output (the return value of forward). You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel
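As far as I understand the message, the suggested detection is just an extra keyword when wrapping the model. A minimal sketch, assuming `model` is an already-constructed nn.Module and the default process group has been initialized (e.g. via torchrun):

from torch.nn.parallel import DistributedDataParallel as DDP

# `model` is assumed to exist already; find_unused_parameters=True is
# the flag named in the error message
ddp_model = DDP(model, find_unused_parameters=True)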
The pseudocode is as follows:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(256, 256)  # placeholder for the frozen layers

    def forward(self, x):
        with torch.no_grad():  # no gradients flow through self.layers
            x = self.layers(x)
        return x

class WholeFlow(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = MyModel()
        self.g = nn.Linear(256, 256)

    def forward(self, x):
        x = self.f(x)
        x = self.g(x)
        return x

model = WholeFlow()
optimizer = torch.optim.SGD(model.g.parameters(), ...)  # only g is optimized
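For context, a minimal sketch of how the model gets wrapped and trained under DDP (single-process, gloo/CPU just for illustration; the lr, batch size, and loss are made up):

import os
import torch
import torch.distributed as dist
import torch.nn as nn

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp_model = nn.parallel.DistributedDataParallel(WholeFlow())
optimizer = torch.optim.SGD(ddp_model.module.g.parameters(), lr=0.01)

out = ddp_model(torch.randn(8, 256))
out.sum().backward()  # f's parameters never receive gradients because of
                      # the no_grad block, which is what DDP complains about
optimizer.step()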
There is a similar issue here: DDP does not work well with `torch.no_grad()` in 1.2 · Issue #6087 · PyTorchLightning/pytorch-lightning · GitHub
It works with DataParallel, but not with DDP.
Any idea?