What's the right way to update a field of a DistributedDataParallel object during training?

What is the correct way to update a model attribute (here a registered buffer) once the model has been wrapped in DistributedDataParallel?

In the example below, if we take a snapshot of self.k on every GPU at the same moment, the values can differ.

Any idea why that happens and how to solve that?

I guess the loss could be different across the models?

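import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(128, 128)
        # buffers must be tensors, not Python ints
        self.register_buffer("k", torch.zeros((), dtype=torch.long))
        # increment the buffer in place; `k + 1` would discard the result
        self.callback = lambda loss, k: k.add_(1)

    def forward(self, x):
        return self.fc(x)

    def training_step(self, x, y):
        y_hat = self(x)
        # fc outputs raw logits, so use the logits version of BCE
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        if loss > 1.:
            self.callback(loss, self.k)
        return loss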
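To check whether the buffer really diverges, every rank has to read it at the same point in training. A minimal sketch, assuming the default process group is already initialized (e.g. with torch.distributed.init_process_group); `snapshot_buffer` is a hypothetical helper name, not part of PyTorch:

```python
import torch
import torch.distributed as dist


def snapshot_buffer(buf: torch.Tensor) -> list:
    """Return the value of `buf` on every rank, indexed by rank.

    Hypothetical helper. Falls back to the local value when no process
    group is initialized (e.g. single-process debugging).
    """
    if not (dist.is_available() and dist.is_initialized()):
        return [buf.item()]
    # one receive tensor per rank, on the same device/dtype as `buf`
    gathered = [torch.zeros_like(buf) for _ in range(dist.get_world_size())]
    # collective: every rank must call this at the same step
    dist.all_gather(gathered, buf)
    return [t.item() for t in gathered]
```

Because all_gather is a collective, every rank must call it, for example right after training_step, and then compare the returned list (a wrapped model's buffer is reachable as ddp_model.module.k).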

Hey @epignatelli

When did you take the snapshot of self.k? DDP broadcasts all buffers from the rank 0 process to the other processes right before calling Model.forward (this is controlled by the broadcast_buffers constructor argument, which defaults to True). So, given the code above, the buffer should be consistent across all processes before self.callback runs.
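One way to keep k identical after the callback is to make the decision itself identical on every rank. Each process computes its loss on its own shard of the data, so `loss > 1.` can hold on some ranks and not others, and the ranks that ran the callback keep a different k until the next forward broadcasts rank 0's copy. The sketch below assumes the condition should use the mean loss across ranks (one possible policy; a MAX all_reduce over a per-rank flag would be another), and `update_counter_consistently` is a hypothetical name:

```python
import torch
import torch.distributed as dist


def update_counter_consistently(k: torch.Tensor, loss: torch.Tensor,
                                threshold: float = 1.0) -> None:
    """Increment buffer `k` only if the mean loss over all ranks exceeds `threshold`.

    Hypothetical helper. all_reduce returns the same value on every rank,
    so every rank takes the same branch and `k` stays identical.
    """
    # detach so the collective does not touch the autograd graph
    global_loss = loss.detach().clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
        global_loss /= dist.get_world_size()
    if global_loss.item() > threshold:
        k.add_(1)  # in-place, so the registered buffer itself changes
```

Inside training_step this would replace the `if loss > 1.:` check with `update_counter_consistently(self.k, loss)`. Since all_reduce is a collective, every rank must reach that line on every step.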


BTW, could you please add a “distributed” tag to distributed-training related questions, so that people working on it can get back to you promptly? Thanks!