I'm fairly new to FSDP. My setup is that I have two networks, `model` and `avg_model`, where `avg_model` is updated to be an exponential moving average (EMA) of `model`. Without FSDP, I can trivially duplicate the EMA update on all GPUs.
What I am trying to achieve is to wrap both `model` and `avg_model` in FSDP (1) identically, and then in each training iteration perform only a local EMA update, where each GPU updates only its local shard of `avg_model` using the local parameters of `model`; this should be possible because both networks are sharded identically.
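For context, the (identical) wrapping looks roughly like this. It's only a sketch: `build_model()` and the size-based auto-wrap threshold are placeholders, not my exact config.

```python
import copy
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

net = build_model()           # build_model() stands in for my actual constructor
avg_net = copy.deepcopy(net)  # avg_model starts as an exact copy of model

# The same auto-wrap policy (and otherwise identical FSDP kwargs) is used for both
# networks, so both end up sharded the same way across ranks.
wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)
model = FSDP(net, auto_wrap_policy=wrap_policy, device_id=torch.cuda.current_device())
avg_model = FSDP(avg_net, auto_wrap_policy=wrap_policy, device_id=torch.cuda.current_device())
```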
I have been able to shard the two networks identically by using the same FSDP config and wrapping policy. Within my training iteration, I need to perform the EMA update of one network using the other, where my `ema_update` function looks roughly like this:
```python
@torch.no_grad()
def ema_update(model, avg_model, eta):
    # Each rank only touches its local portion of the two (identically sharded) networks.
    model_params = dict(get_unwrapped_model(model).named_parameters())
    avg_model_params = dict(get_unwrapped_model(avg_model).named_parameters())
    for param_name, model_param in model_params.items():
        avg_model_param = avg_model_params[param_name]
        avg_model_param.copy_(eta * avg_model_param + (1 - eta) * model_param)
```
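For what it's worth, the same step can equivalently be written with in-place ops (just a sketch of the identical update, to show there is nothing exotic in it):

```python
# Equivalent in-place form of the EMA step: avg <- eta * avg + (1 - eta) * param
avg_model_param.mul_(eta).add_(model_param, alpha=1 - eta)
```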
Interestingly, when I run this `ema_update` function, along with a `torch.distributed.barrier()`, I don't get the same results (within FSDP non-determinism) unless I also run the following `summon_full_params` after my EMA update:
```python
ema_update(model, avg_model, eta)
torch.distributed.barrier()
with FSDP.summon_full_params(model, writeback=True, rank0_only=False):
    pass
```
It seems like the `summon_full_params` context manager should not be doing anything apart from all-gathering the parameters and then continuing, yet it changes the outcome of training and ends up being necessary to achieve parity. Without it, the `avg_model` weights do appear to be updated separately on each GPU, but the results/losses are slightly off from what I expect. How can this behaviour be explained?
It is worth bearing in mind that there are other parts of my training loop where the normal, non-averaged `model` is updated via a typical Adam optimizer.
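Putting it all together, a single training iteration looks roughly like this (schematic only; the loss computation and hyperparameters are simplified stand-ins):

```python
for step, batch in enumerate(dataloader):
    # Usual Adam update of the non-averaged model.
    optimizer.zero_grad()
    loss = compute_loss(model, batch)   # compute_loss is a stand-in for my actual loss
    loss.backward()
    optimizer.step()

    # Local EMA update of avg_model from model, plus the summon_full_params
    # workaround that apparently ends up being needed to get parity.
    ema_update(model, avg_model, eta)
    torch.distributed.barrier()
    with FSDP.summon_full_params(model, writeback=True, rank0_only=False):
        pass
```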