Initializing a parameter from data during the first forward pass

I have a parameter that is initialized during the first forward pass:

def forward(self, x):
    if not self.initialized:
        # assign to self.a so the parameter is registered on the module, not a local variable
        self.a = nn.Parameter(torch.mean(x))
        self.initialized = True

Note that x is the input data.

The problem: under DDP, each device receives a different batch, so a is initialized to a different value on each device. I don't think setting a random seed can solve this; it's more of a synchronization problem.
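To see why a seed does not help, here is a small sketch assuming a toy dataset of the values 0 to 7, two replicas and shuffle=False (all hypothetical choices). DistributedSampler accepts explicit num_replicas and rank, so no process group is needed to inspect what each rank would receive. Each rank gets a disjoint shard by design, so the first-batch means differ; a seed only makes each rank's shard reproducible, not identical.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def first_batch_means(num_replicas=2, batch_size=2):
    # Hypothetical toy dataset: the values 0.0 .. 7.0
    data = TensorDataset(torch.arange(8, dtype=torch.float32))
    means = []
    for rank in range(num_replicas):
        # Explicit num_replicas/rank: no torch.distributed process group required.
        sampler = DistributedSampler(
            data, num_replicas=num_replicas, rank=rank, shuffle=False
        )
        loader = DataLoader(data, batch_size=batch_size, sampler=sampler)
        (x,) = next(iter(loader))  # first batch this rank would see
        means.append(x.mean().item())
    return means


print(first_batch_means())  # [1.0, 2.0]: rank 0 sees [0, 2], rank 1 sees [1, 3]
```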

How can I initialize it identically on all devices?

Thanks for your help!

I've added more information. I don't think this can be solved by setting a random seed.

Sorry, I misinterpreted your case. I thought a referred to a module parameter, but x is a batch from your dataset.

I don't see a straightforward solution, since the input batch will inevitably differ per device. However, you could get a rough estimate as follows:

  • register a as a parameter when instantiating your model instead of defining it in the forward method:
class Model(torch.nn.Module):

    def __init__(self, a_init):
        super().__init__()  # required before registering parameters
        self.a = torch.nn.Parameter(a_init)
  • compute a_init before constructing the model:
a_init = torch.mean(x)  # x: the whole dataset as a single tensor
model = Model(a_init)

where x is your whole dataset. This initializes a to the dataset-wide mean, which, for equal-size batches, equals the average of the per-batch means. Because every process computes a_init from the same full dataset, all ranks start with the same value.
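A quick check of the idea, assuming a hypothetical dataset of 64 random samples split into 4 equal batches of 16: the dataset-wide mean matches the average of the per-batch means, so it is a reasonable stand-in for the value a first batch would produce, and it is computed identically on every rank.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, a_init):
        super().__init__()
        # clone so the parameter does not share storage with the caller's tensor
        self.a = nn.Parameter(a_init.clone())


# Hypothetical full dataset: 64 samples, 3 features, float64 to keep rounding tiny.
torch.manual_seed(0)
x = torch.randn(64, 3, dtype=torch.float64)
batch_size = 16

dataset_mean = torch.mean(x)
# Means of the 4 equal-size batches of 16 samples each.
batch_means = torch.stack([b.mean() for b in x.split(batch_size)])

# With equal-size batches, the dataset mean equals the average of the batch means.
print(torch.allclose(dataset_mean, batch_means.mean()))  # True

model = Model(dataset_mean)
```

If the last batch is smaller, the two quantities drift apart slightly, which is still fine for a rough initial value.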


Distributed collectives such as broadcast and all_reduce can make the initial value identical across workers.

For example, you can use broadcast to copy rank 0's value to all ranks, or all_reduce with ReduceOp.SUM followed by division by the world size to take the mean across all ranks. The different collectives are documented here: Distributed communication package - torch.distributed — PyTorch 1.10.0 documentation
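Both options can be sketched with two CPU processes. This assumes the gloo backend, a file-based rendezvous in a temporary directory, the POSIX-only "fork" start method, and hypothetical per-rank data where rank r only sees the value r + 1; none of these come from the thread.

```python
import multiprocessing
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def sync_init(rank, world_size, init_file, results):
    # Rendezvous through a shared file so no network port has to be chosen.
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{init_file}",
        rank=rank,
        world_size=world_size,
    )
    # Hypothetical per-rank batch: rank r only sees the value r + 1.
    x = torch.full((4,), float(rank + 1))
    local_mean = x.mean()

    # Option 1: broadcast overwrites every rank's tensor with rank 0's value.
    from_rank0 = local_mean.clone()
    dist.broadcast(from_rank0, src=0)

    # Option 2: all_reduce sums in place across ranks; divide to get the mean.
    averaged = local_mean.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= world_size

    results[rank] = (from_rank0.item(), averaged.item())
    dist.destroy_process_group()


def run(world_size=2):
    # The rendezvous file must not exist yet; a fresh temp directory ensures that.
    init_file = os.path.join(tempfile.mkdtemp(), "dist_init")
    with multiprocessing.Manager() as manager:
        results = manager.dict()
        # "fork" is POSIX-only; on Windows use mp.spawn with the same arguments.
        mp.start_processes(
            sync_init,
            args=(world_size, init_file, results),
            nprocs=world_size,
            start_method="fork",
        )
        return dict(sorted(results.items()))


if __name__ == "__main__":
    print(run())  # {0: (1.0, 1.5), 1: (1.0, 1.5)}
```

Broadcast makes every rank adopt rank 0's first-batch mean (1.0); the all_reduce variant gives every rank the mean of all ranks' values (1.5). Either way the ranks agree.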


Thanks, this solved my problem:

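def forward(self, x):
    if not self.initialized:
        # mean of this rank's first batch
        mean_x = torch.mean(x)
        # sum the per-rank means in place across all processes
        torch.distributed.all_reduce(mean_x, op=torch.distributed.ReduceOp.SUM)
        # divide by the number of processes: the same mean on every rank
        self.a = nn.Parameter(mean_x / torch.distributed.get_world_size())
        self.initialized = True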
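One caveat: DistributedDataParallel collects the module's parameters when it is constructed, and an optimizer built beforehand also holds a fixed parameter list, so a parameter first created inside forward is missed by both. A variant worth considering (swapped in here, not what the thread used) registers a placeholder in __init__ and fills it in place on the first forward. The module name DataInitShift and its x - a behaviour are hypothetical.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


class DataInitShift(nn.Module):
    """Hypothetical module: returns x - a, where the learnable offset `a` is set
    to the mean of the first batch seen (averaged over ranks when distributed)."""

    def __init__(self):
        super().__init__()
        # Register `a` up front so DDP and the optimizer see it from the start.
        self.a = nn.Parameter(torch.zeros(()))
        self.initialized = False

    def forward(self, x):
        if not self.initialized:
            with torch.no_grad():
                mean_x = x.mean()
                if dist.is_available() and dist.is_initialized():
                    # Collective call: every rank must reach this on the same step.
                    dist.all_reduce(mean_x, op=dist.ReduceOp.SUM)
                    mean_x /= dist.get_world_size()
                # In-place copy keeps the same Parameter object the optimizer holds.
                self.a.copy_(mean_x)
            self.initialized = True
        return x - self.a


model = DataInitShift()
model(torch.tensor([1.0, 2.0, 3.0]))  # first call sets a to the batch mean
print(model.a.item())  # 2.0
```

Construct the model, wrap it with DistributedDataParallel, then run the first forward on every rank; because a already exists when DDP is built, its gradient is synchronized like any other parameter. The plain Python flag is not saved in state_dict, so after loading a checkpoint you would set initialized = True yourself.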