Hi, I tried using Adam(…, fused=True) and it does yield a 10-15% speedup on 1 GPU, as expected. However, when I use it with DDP (via Lightning) on 2 GPUs, it yields no speedup at all. Is this expected behavior? I realize this may be an issue with Lightning rather than torch, in which case this would not be the appropriate forum, but I just wanted to check whether anyone has experienced similar behavior. (I'm running the latest torch and Lightning on a machine with two 2080 Tis.)
Also, I was wondering if there are any plans to implement fused AdamW? Using AdamW’s weight decay is quite important for transformer training and, assuming there are speedups to be had, it would be great to have a fused version of it.
You might need to profile the actual DDP workload in your Lightning setup, check the kernel execution, find where the bottleneck of your training is, etc. Based on your observation, the speedup from Adam(..., fused=True) might be wiped out by another slowdown (e.g. the gradient all-reduce dominating the step time).
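A minimal profiling sketch along those lines, using `torch.profiler` on a tiny stand-in model (the model, batch size, and iteration count here are placeholders, not your actual workload; in the DDP case you would wrap the model and look for NCCL all-reduce entries in the table):

```python
# Sketch: profile a few training steps to see where time actually goes.
# Uses a small CPU-friendly module so the snippet runs anywhere; swap in
# your real model/optimizer and enable CUDA profiling on the GPU nodes.
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(256, 256)
opt = torch.optim.Adam(model.parameters())  # fused=True requires CUDA params
data = torch.randn(64, 256)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    for _ in range(5):
        opt.zero_grad()
        model(data).sum().backward()
        opt.step()

# Check whether optimizer kernels or communication (all-reduce) dominate.
report = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(report)
```

If the optimizer step is only a small slice of the profiled time (which is likely once DDP communication enters the picture), a 10-15% optimizer speedup would be invisible end to end.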
@crcrpar has implemented fused AdamW in this PR and it should already be available if I’m not mistaken.
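If it is available in your build, requesting it looks the same as for fused Adam; a hedged sketch (the model and hyperparameters here are placeholders, and `fused=True` only works with CUDA parameters, so this guards on availability):

```python
# Sketch: fused AdamW is requested via the same `fused` flag as Adam.
# fused=True requires CUDA tensors, so fall back to the default
# implementation when no GPU is present.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(128, 128).to(device)
opt = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,           # decoupled weight decay, as in AdamW
    fused=torch.cuda.is_available(),  # only valid on CUDA params
)

x = torch.randn(32, 128, device=device)
opt.zero_grad()
model(x).sum().backward()
opt.step()
```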