Copying params between 2 identically sharded (FSDP) networks

I’m fairly new to FSDP. My setup is that I have two networks, model and avg_model, where avg_model is updated to be an exponential moving average (EMA) of model. Without FSDP, I can trivially duplicate the EMA update on all GPUs.

What I am trying to achieve is to wrap both model and avg_model in FSDP identically, and then in each training iteration perform only a local EMA update, where each GPU updates just its local portion of avg_model using the local parameters of model. This should be possible because both networks are sharded identically.

I have been able to shard the two networks identically by using the same FSDP config and wrapping policy. Within my training iteration, I need to perform an EMA update of one network using the other, and my ema_update function looks roughly like this:

def ema_update(model, avg_model, eta):
    model_params = dict(get_unwrapped_model(model).named_parameters())
    avg_model_params = dict(get_unwrapped_model(avg_model).named_parameters())
    with torch.no_grad():
        for param_name, model_param in model_params.items():
            avg_model_param = avg_model_params[param_name]
            avg_model_param.copy_(eta * avg_model_param + (1 - eta) * model_param)
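
(For illustration only, here is a minimal sketch of how I think of the shard-local update written directly against the FSDP-wrapped modules, without unwrapping. This is not my exact training code, and it assumes both modules are constructed and wrapped identically, so that parameters() yields matching local shards in the same order on every rank:)

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

@torch.no_grad()
def ema_update_sharded(model: FSDP, avg_model: FSDP, eta: float):
    # Assumes identical wrapping, so both parameters() iterators yield
    # shards with matching order and local shapes on every rank.
    for model_param, avg_param in zip(model.parameters(), avg_model.parameters()):
        # In-place EMA on the local shard only; no communication involved.
        avg_param.mul_(eta).add_(model_param, alpha=1.0 - eta)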

Interestingly, when I run this ema_update function, followed by a torch.distributed.barrier(), I don’t get the same results (within FSDP non-determinism) unless I also run the following summon_full_params block after my EMA update:

ema_update(model, avg_model, eta)
torch.distributed.barrier()
with FSDP.summon_full_params(model, writeback=True, rank0_only=False):
    pass

It seems like the summon_full_params context manager should not be doing anything apart from all-gathering the parameters and then continuing, yet it changes the outcome of training and ends up being necessary to achieve parity. Without it, the avg_model weights do appear to be updated separately on each GPU, but the results/losses are slightly off from what I expect. How can this behaviour be explained?

It is worth bearing in mind that there are other parts of my training loop where the normal, non-averaged model is updated via a typical Adam optimizer.

@Timofey_Abramski did you figure out a solution?

No, I haven’t. The best I have come up with is to do the apparently unnecessary

with FSDP.summon_full_params(model, writeback=True, rank0_only=False):
    pass

immediately after doing the local update, which ensures I get the correct results. This feels kind of hacky and seems unnecessary on top of the torch.distributed.barrier(), but without it I get incorrect results. It seems like there is some FSDP internal state that the summon forces to be correct, although I’m not sure how this works internally.