What's the right way to update a field of a DistributedDataParallel object during training?

What is the correct way to update a class variable once the model has been wrapped in DistributedDataParallel?

In the example below, if we take a snapshot of self.k on each GPU's model replica at the same point in training, we can get different values.

Any idea why that happens and how to solve that?

I guess loss would be different across all the models?

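import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()  # must run before registering submodules or buffers
        self.fc = torch.nn.Linear(128, 128)
        # register_buffer expects a tensor, not a Python int
        self.register_buffer("k", torch.zeros(1))
        # the callback receives the loss and the current buffer value
        self.callback = lambda loss, k: k + 1

    def forward(self, x):
        return self.fc(x)

    def training_step(self, x, y):
        y_hat = self(x)
        # fc returns raw logits, so use the logits variant of BCE
        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)
        if loss > 1.:
            # write the result back in place so the registered buffer is updated
            self.k.copy_(self.callback(loss, self.k))
        return loss
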
Hey @epignatelli

When did you take the snapshot of self.k? DDP broadcasts all buffers from the rank 0 process to the other processes right before calling Model.forward. See the code below. So, given the code above, the buffer should be consistent across all processes before self.callback is invoked.
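
To make the broadcast behaviour concrete, here is a minimal sketch of syncing a buffer from rank 0 by hand with torch.distributed.broadcast. It is not the DDP source referred to above, only an illustration. The helper name sync_buffer_from_rank0 and the demo function are hypothetical, and the single-process gloo group with a file-based init_method is an assumption made so the sketch can run on one CPU; in real training the launcher sets the rank and world size.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def sync_buffer_from_rank0(buf: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: overwrite `buf` on every rank with rank 0's copy."""
    if dist.is_available() and dist.is_initialized():
        # broadcast works in place: after the call, every rank holds rank 0's values
        dist.broadcast(buf, src=0)
    return buf


def demo() -> float:
    # Assumption: a one-process gloo group on CPU, only so the sketch is runnable.
    with tempfile.TemporaryDirectory() as tmp:
        init_path = os.path.join(tmp, "ddp_init")
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{init_path}",
            rank=0,
            world_size=1,
        )
        try:
            k = torch.zeros(1)
            k += 1  # local in-place update, as a rank might do in training_step
            sync_buffer_from_rank0(k)  # ranks other than 0 would be overwritten here
            return k.item()
        finally:
            dist.destroy_process_group()


if __name__ == "__main__":
    print(demo())
```

With more than one process, an update made on a rank other than 0 is discarded by this kind of broadcast, which is one way replicas can disagree between an update and the next sync. If every rank's update should count, a reduction such as dist.all_reduce may fit better. DDP's own buffer sync is controlled by its broadcast_buffers constructor argument.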

BTW, could you please add a “distributed” tag to distributed-training related questions? So that people working on it can get back to you promptly. Thanks!