I think that if you have a DDP-wrapped model A:
class A(nn.Module):
    ...

model = A()
model = DDP(model)
then you must call output = model(input) at least once per training iteration, no matter what the forward function actually does; otherwise it can cause problems.
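For reference, here is a minimal sketch of the usual DDP training loop, where every iteration goes through the wrapper's forward. The tiny nn.Linear and the synthetic batches are placeholders, and it assumes the process group is already initialized (e.g. by torchrun):

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# assumes dist.init_process_group() has already run, e.g. via torchrun
local_rank = dist.get_rank() % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")

net = nn.Linear(10, 2).to(device)   # stand-in for any nn.Module
model = DDP(net, device_ids=[local_rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for _ in range(10):
    inputs = torch.randn(8, 10, device=device)          # synthetic batch
    targets = torch.randint(0, 2, (8,), device=device)
    optimizer.zero_grad()
    outputs = model(inputs)        # call THROUGH the DDP wrapper
    loss = criterion(outputs, targets)
    loss.backward()                # DDP all-reduces gradients here
    optimizer.step()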
For example, I ran into this with the following code:
# A ResNet18 demo for MNIST:
import torch.nn as nn
import torchvision
from torchvision.models import ResNet18_Weights

class ResNet18(nn.Module):
    def __init__(self, num_classes: int):
        super(ResNet18, self).__init__()
        if is_main_process():
            # download pretrained weights only on the main process
            self.model = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        else:
            self.model = torchvision.models.resnet18(weights=None)
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        # return self.model(x)
        return x.detach()  # this forward does no computation at all

    def fake_forward(self, x):
        return self.model(x)

def build(config: dict):
    return ResNet18(num_classes=10)
And in my training function, if I call:
outputs = model.module.fake_forward(images.to(device))
then the parameters of my model become inconsistent across GPUs (ranks).
But if I do:
images = model(images)  # although it does nothing!
# or: images = model.forward(images)
outputs = model.module.fake_forward(images.to(device))
the parameters stay consistent.
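One way to observe this is to compare a fingerprint of the weights across ranks after each optimizer step. A rough sketch, assuming an initialized process group (parameter sums can in principle collide, so this is a spot check rather than a proof):

import torch
import torch.distributed as dist

def params_consistent(model: torch.nn.Module) -> bool:
    # cheap fingerprint: sum of all parameter values on this rank
    device = next(model.parameters()).device
    local = torch.zeros(1, device=device)
    for p in model.parameters():
        local += p.detach().float().sum()
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    # under healthy DDP, every rank holds identical weights
    return all(torch.allclose(gathered[0], g) for g in gathered)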
Thus, I think the key to the synchronization is some code inside DistributedDataParallel's own forward(): as far as I can tell, the wrapper's forward pass is what arms the gradient all-reduce for the subsequent backward, so bypassing it with model.module.fake_forward(...) leaves each rank applying purely local gradients, and the replicas drift apart.
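Under that assumption, a sketch of the fix is simply to keep the real computation inside forward() and always call through the wrapper; Wrapped here is a hypothetical stand-in for the ResNet18 class above:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class Wrapped(nn.Module):
    # hypothetical stand-in for the ResNet18 class above
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        # keep the real computation in forward(), not a side method
        return self.fc(x)

# model = DDP(Wrapped().to(device), device_ids=[local_rank])
# outputs = model(images)                       # runs DDP's sync machinery
# NOT: outputs = model.module.forward(images)   # bypasses DDP entirely

And if some iterations genuinely must skip gradient synchronization, DDP's no_sync() context manager is the supported way to do that, rather than calling around the wrapper.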