I’m trying to get SWA (stochastic weight averaging / EMA) working for my model, but whatever I try, the averaged model I get afterwards does not work. Well, it technically runs, but the predictions are simply wrong.
The way I currently do it: I start from my current best model as a checkpoint, train for another 20 epochs with a reduced learning rate, and then use the averaged model.
I train a mixed-precision model using DDP. The way I compute the average is, I think, straightforward:
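Since my averaging code isn’t shown here, this is the running-mean update I rely on (the same formula `torch.optim.swa_utils.AveragedModel` uses by default), sketched with plain floats standing in for tensors — the dict keys are illustrative, not my real layer names:

```python
def update_average(avg_state, new_state, n_averaged):
    """SWA running mean: avg <- avg + (w - avg) / (n_averaged + 1),
    where n_averaged is the number of models already averaged."""
    return {
        k: avg + (new_state[k] - avg) / (n_averaged + 1)
        for k, avg in avg_state.items()
    }

# Toy "state_dicts": scalar weights instead of tensors.
avg = {"layer.weight": 1.0}
avg = update_average(avg, {"layer.weight": 3.0}, 1)  # mean of 1, 3
avg = update_average(avg, {"layer.weight": 5.0}, 2)  # mean of 1, 3, 5
```

In the real model the same update runs per parameter tensor, so the arithmetic itself is hard to get wrong; the problems usually come from *which* state_dict the update is applied to.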
I was wondering whether the way I serialize the model could be the problem, because the layer keys in the averaged model’s `model_state_dict` differ slightly from the layer names in the non-averaged model:
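For what it’s worth, both `DataParallel`/`DistributedDataParallel` and `AveragedModel` wrap the network in a `.module` attribute, so their saved keys carry a `module.` prefix (and `AveragedModel` additionally stores an `n_averaged` buffer). A minimal sketch of stripping that prefix before loading into a bare model — the key names below are made up for illustration:

```python
def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that DataParallel/DDP (and
    AveragedModel) prepend to parameter keys when saving."""
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

saved = {"module.conv1.weight": 0.1, "module.fc.bias": 0.2}
clean = strip_module_prefix(saved)  # keys: "conv1.weight", "fc.bias"
```

Recent PyTorch versions also ship `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which does the same thing in place.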
The training runs with the old “DataParallel” mode, while the prediction runs with “DistributedDataParallel”. So far this has not been a problem, but it turns out that running the averaged model with DDP gives me corrupted results, while running the same model with DP gives decent results.
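One quick check I can do to see whether key naming is the culprit: diff the key sets of the target model and the checkpoint before calling `load_state_dict`. (If `load_state_dict` is called with `strict=False`, mismatched keys are silently skipped and the affected weights stay at their initial values, which would look exactly like corrupted predictions.) The key names here are again just placeholders:

```python
def diff_keys(model_keys, checkpoint_keys):
    """Report keys each side has that the other does not."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected

model_keys = ["conv1.weight", "fc.bias"]
ckpt_keys = ["module.conv1.weight", "module.fc.bias", "n_averaged"]
missing, unexpected = diff_keys(model_keys, ckpt_keys)
# missing: every model key; unexpected: every checkpoint key
```

If the diff is non-empty, the fix would be to remap the keys (or load into the wrapped `.module`) before handing the model to DDP.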