OK, I took the day to debug; a lot of information here:
(1) The drop in max batch size varies. And of course, when I increase the batch size, training takes longer, so that doesn't solve the problem.
(2) I managed to roll back my NVIDIA driver. I followed this tutorial (apt - Ubuntu 18.10: How can I install a specific NVIDIA drivers version? - Ask Ubuntu), but they forgot to say that I also had to purge libnvidia-* (something like sudo apt purge 'libnvidia-*').
(3) I managed to test with my old NVIDIA driver: same problem with PyTorch 1.5.
(4) I still had warnings I didn't have before, so I tried PyTorch 1.4, and now it works fine: I'm back at 0.45 s/it with no warnings, so I guess that's the version I was originally using.
(5) If I try to go back to PyTorch 1.7, it says "NVIDIA driver on your system is too old", so I can't benchmark it easily; but it works with 1.5.1, and I can see the slowdown just by switching my conda env (the environment check I use is sketched right after this list).
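For reference, the quick environment check (standard torch calls, run in each conda env to confirm which stack is active):

```python
import torch

# Print the exact stack the current conda env is using.
print("torch:", torch.__version__)              # e.g. 1.4.0 / 1.5.1 / 1.7.1
print("cuda (compiled):", torch.version.cuda)   # CUDA version PyTorch was built against
print("cudnn:", torch.backends.cudnn.version())
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```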
Conclusion: the slowdown appeared in PyTorch 1.5.1 and is not present in PyTorch 1.4.0, on my setup (hardware, driver, software).
Pic of score = f(time): [Imgur link]
(PS: the lower score is just due to the steadily slower training; when I compare score = f(iteration), everything matches except the dark-orange curve, which uses batch size 200 (I didn't adjust the LR). I also ran two other experiments that confirm the picture.)
Note: this code uses a few extras (DataParallel, some JIT functions, Apex), so it's hard for me to know exactly what is causing the slowdown. I could perhaps try basic benchmarks if it's worth investigating (sketched below). Maybe my Apex build just doesn't work with PT 1.5+ (but it's the same code, and the warnings were about completely different things; I fixed the warnings and the slowdown remained).
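By "basic benchmarks" I mean something like the following: a dummy model timed with proper CUDA synchronization, with none of my DataParallel/JIT/Apex code involved (a minimal sketch, not my actual training loop):

```python
import time
import torch

# Dummy model and batch, just to time raw forward/backward per iteration.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024)
).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(235, 1024, device="cuda")

for i in range(110):
    if i == 10:                      # start timing after 10 warmup iterations
        torch.cuda.synchronize()
        start = time.time()
    optimizer.zero_grad()
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()             # wait for the GPU before reading the clock
print("time/it: %.4f s" % ((time.time() - start) / 100))
```

Running this once per conda env would tell me whether the slowdown is in plain PyTorch ops or in one of the extras.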
Setups:
OLDER VERSION
- Driver: Nvidia 435
- Cuda: 10.1 (the conda command I ran to install PyTorch in late Feb 2020 clearly says 10.1)
- Cudnn: 7.6.5 (maybe? It's a package I have, but I also just ran the standard PyTorch install command)
- Pytorch: 1.4 !!! (not 1.5; I say that because I reproduced the behavior in 1.4. I didn't log my old PT version, but I know it was below 1.6, I installed it in Feb 2020, PT 1.5 is from April 2020, and 1.5.1 gives warnings I never had)
- Batch size: 235 (the batch size I actually used; I don't think it was the max)
- Time_per_iteration (steady during 2 days of training, +/- 0.02 s): 0.45 s
- note: this setup is from Feb 2020, so (almost) everything I had installed is from Feb 2020, so it wasn't PT 1.5
AFTER UPDATE
- Driver: Nvidia 460.32.03
- Cuda: 11.2
- Cudnn: 8.0.5
- Pytorch: 1.7.1
- Maximum batch size (before OOM): 216 / 224
- Time_per_iteration: 0.50 s / 0.51 s
- conda: pytorch 1.7.1 py3.7_cuda11.0.221_cudnn8.0.5_0
PYTORCH ROLLBACK (PT1.5.1)
- Driver: Nvidia 460.32.03
- Cuda: 10.1 according to the conda package, but 11.2 according to nvidia-smi
- Cudnn: 7.6.3 (according to the conda package)
- Pytorch: 1.5.1
- Maximum batch size (before OOM): 235
- Time_per_iteration: 0.52 s
- conda: pytorch 1.5.1 py3.7_cuda10.1.243_cudnn7.6.3_0
- note: Results are the same with Nvidia 435
FULL ROLLBACK: PT1.4 + NVIDIA 435 (Nvidia 460 would probably be fine)
- Driver: Nvidia 435.21
- Cuda: 10.1
- Cudnn: 7.6.3
- Pytorch: 1.4.0
- Maximum batch size (before OOM): 312 (yes, really). But all my experiments are done with 235
- Time_per_iteration: 0.45 s
- conda: pytorch 1.4.0 py3.7_cuda10.1.243_cudnn7.6.3_0
Overall: when I simply have two different conda envs side by side in two consoles, one with PT 1.5 and the other with PT 1.4, I can see the slowdown between the two consoles (the delta_t is quite stable).
For the batch size: I can reach up to 312 with PT 1.4, but all the plots / other experiments are done with bs=235 (or the maximum available, 224 for 1.7), except the dark-orange plot. A sketch of how I probe the max batch size follows.
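The probing itself is just trial and error: grow the batch until the forward/backward pass raises an OOM error. A hypothetical sketch, where run_one_iteration(bs) stands in for one training step at batch size bs:

```python
import torch

def max_batch_size(run_one_iteration, low=1, high=1024):
    """Binary-search the largest batch size that doesn't OOM
    (assumes `low` itself fits in memory)."""
    while low < high:
        mid = (low + high + 1) // 2
        try:
            run_one_iteration(mid)     # one forward/backward at batch size `mid`
            low = mid                  # `mid` fits: try larger
        except RuntimeError as e:      # CUDA OOM surfaces as a RuntimeError
            if "out of memory" not in str(e):
                raise
            torch.cuda.empty_cache()   # release the failed allocation
            high = mid - 1             # `mid` OOMs: try smaller
    return low
```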
The "Time_per_iteration" barely varies during training; that's why I can tell there's a 10+% slowdown (all the plots I made on my side confirm this). It's not jumping +/- 0.10 s, it's +/- 0.02 s, so the 0.05 s difference is clear.
I checked the PyTorch version with "import torch; torch.__version__" to be sure, but when I say "conda" in the setups it's literally the output of "conda list | grep pytorch", so there's no doubt for me.
The difference in maximum batch size makes me think it could be a problem with how fp16 is handled.
I'm glad I solved my problem, but I'm sad that my code gets slower if I want to update. Then again, maybe it's only slower with the methods I'm using, and updating my code to native fp16 & co would make it faster? See the sketch below.
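By "native fp16 & co" I mean the torch.cuda.amp API that ships with PyTorch 1.6+, which would replace Apex in my training loop. Roughly like this (a minimal sketch with a dummy model, not my actual code):

```python
import torch

model = torch.nn.Linear(1024, 10).cuda()     # dummy model for the sketch
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for _ in range(100):
    data = torch.randn(235, 1024, device="cuda")
    target = torch.randint(0, 10, (235,), device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # forward pass in mixed precision
        loss = criterion(model(data), target)
    scaler.scale(loss).backward()            # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)                   # unscales the grads, then steps
    scaler.update()
```

That would also remove the dependency on a separately built Apex, which is one of my suspects above.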