I have a parameter, and it's initialized during the first forward pass by:
def forward(self, x):
    if not self.initialized:
        # lazily create the parameter from the first input batch
        self.a = nn.Parameter(torch.mean(x))
        self.initialized = True
Note that x is the input data.
So the problem is: when using DDP, the batch data is different on each device, so parameter a is initialized differently on each device. This can't be solved by setting a random seed, I think; it's more of a sync problem.
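For context, here is a minimal self-contained sketch of the pattern above; the class name, the multiplicative use of a in the return line, and the fake batches are placeholders, not the real model:

import torch
import torch.nn as nn

class LazyInitModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def forward(self, x):
        if not self.initialized:
            # a is created from whatever batch this replica sees first,
            # so under DDP each device ends up with a different value
            self.a = nn.Parameter(torch.mean(x))
            self.initialized = True
        return x * self.a

# simulate two DDP ranks that each see a different first batch
model_rank0, model_rank1 = LazyInitModel(), LazyInitModel()
model_rank0(torch.randn(32, 8))
model_rank1(torch.randn(32, 8))
print(model_rank0.a, model_rank1.a)  # two different parameter values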
Sorry, I misinterpreted your case. I thought a was referring to a module parameter, but x refers to a batch of your dataset.
I don't see any straightforward solution, as the batch of inputs will inevitably vary per device. However, you could get a rough estimate by doing the following:
Register a as a parameter when instantiating your model instead of defining it in the forward method:
class Model(torch.nn.Module):
    def __init__(self, a_init):
        super().__init__()
        self.a = torch.nn.Parameter(a_init)
Instantiate a while initializing your model, as follows:
a_init = torch.mean(x) / n_batch
model = Model(a_init)
where n_batch denotes your number of batches and x is your whole dataset. This way, a is initialized to the average of your dataset, rescaled by the number of batches.
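Putting the two steps together, a rough end-to-end sketch might look like this; the synthetic dataset, n_batch, and the multiplicative use of a in forward are assumptions standing in for your own pipeline:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, a_init):
        super().__init__()
        # a is registered at construction time, so it exists on every replica
        self.a = nn.Parameter(a_init)

    def forward(self, x):
        # placeholder computation; replace with your real model
        return x * self.a

# synthetic stand-in for the whole dataset: 10 batches of 32 samples, 8 features each
dataset = torch.randn(10, 32, 8)
n_batch = dataset.shape[0]

# computed from the full dataset, so every rank gets the same value deterministically
a_init = torch.mean(dataset) / n_batch
model = Model(a_init)

# from here the model can be wrapped in DistributedDataParallel as usual,
# since a is already part of the module state on every process
print(model.a)

Because a_init is derived from the same data on every process rather than from a per-device batch, there is nothing left to synchronize.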