Fused Adam slow, request for fused AdamW

Hi, I tried using Adam(…, fused=True) and it does yield a 10-15% speedup on 1 GPU, as expected. However, when I use it with DDP (via Lightning) on 2 GPUs, it yields no speedup at all. Is this expected behavior? I realize this may be an issue with Lightning rather than torch, in which case this would not be the appropriate forum, but I would like to check whether anyone has experienced similar behavior. (I'm running the latest torch and lightning on a machine with two 2080 Tis.)
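For context, this is roughly how I'm wiring it up (a sketch with a placeholder model, not my actual code; the real model is a transformer):

```python
import torch
import lightning as L


class LitModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Placeholder model standing in for the actual transformer.
        self.model = torch.nn.Linear(128, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.model(x), y)

    def configure_optimizers(self):
        # fused=True runs the parameter update as a single fused CUDA kernel
        return torch.optim.Adam(self.parameters(), lr=1e-3, fused=True)


# Single GPU shows the ~10-15% speedup; with DDP on 2 GPUs it disappears.
trainer = L.Trainer(accelerator="gpu", devices=2, strategy="ddp")
```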

Also, I was wondering whether there are any plans to implement a fused AdamW? AdamW's weight decay is quite important for transformer training and, assuming there are speedups to be had, it would be great to have a fused version of it as well.

Thank you so much in advance!

You might need to profile the actual DDP workload in your Lightning setup and check the kernel execution to see where the bottleneck of your training is. Based on your observation, the speedup from Adam(..., fused=True) might be wiped out by another slowdown.
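Something like this torch.profiler sketch could show where the time is going; `run_one_step` is a placeholder for a single forward/backward/optimizer iteration of your loop, not a real Lightning hook:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Profile a handful of steps and check whether the fused optimizer kernels,
# the DDP allreduce, or something else (e.g. data loading) dominates.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    for _ in range(10):
        run_one_step()  # placeholder: one training iteration

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```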

@crcrpar has implemented fused AdamW in this PR and it should already be available if I’m not mistaken.
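A quick sanity check, assuming a recent PyTorch build where AdamW accepts fused=True (the fused path requires parameters on CUDA):

```python
import torch

model = torch.nn.Linear(128, 10).cuda()  # placeholder model on GPU

# AdamW with decoupled weight decay, using the fused CUDA kernel path.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, weight_decay=0.01, fused=True
)
```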

Right, that makes sense, I will look into it.

I checked now and it is indeed available! Apologies for this; I recently upgraded to torch 2, and I suppose the previous version didn't have fused AdamW yet.

Thanks a lot for the help and kudos to you for all the hard work you do for the community!
