Using apex for loading checkpoints

Hi, I used apex to train my model in half precision. When I load the checkpoints, should I also load the apex checkpoint state?

If you want to continue training with apex, you would have to restore the apex state_dict as well.
However, we recommend checking out the native mixed-precision implementation as described here, which provides a smoother experience than apex.
Note that you would have to install the nightly binaries or build from source to use it.
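As a minimal sketch of the native-amp equivalent (assuming a PyTorch version with `torch.autocast`, i.e. 1.10 or later; the tiny `nn.Linear` model and the `make_training_objects`/`train_step` helpers are hypothetical stand-ins for a real 3D network), the `GradScaler` state is saved and restored alongside the model and optimizer. With apex, the analogous calls are `amp.state_dict()` when saving and `amp.load_state_dict()` when restoring, called after `amp.initialize` on the freshly created model and optimizer.

```python
import io

import torch
import torch.nn as nn

# Mixed precision here needs a GPU; on CPU-only machines autocast and the
# scaler are simply disabled so the sketch still runs end to end.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"


def make_training_objects():
    """Hypothetical model/optimizer/scaler standing in for the real setup."""
    torch.manual_seed(0)
    model = nn.Linear(16, 4).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    return model, optimizer, scaler


def train_step(model, optimizer, scaler, x, y):
    optimizer.zero_grad()
    # Forward pass under autocast (a no-op when use_amp is False).
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y)
    # Scale the loss, backprop, then unscale/step and update the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


model, optimizer, scaler = make_training_objects()
x = torch.randn(8, 16, device=device)
y = torch.randn(8, 4, device=device)
train_step(model, optimizer, scaler, x, y)

# Save model, optimizer and scaler state together. An in-memory buffer is
# used here; a file path such as "checkpoint.pt" works the same way.
buffer = io.BytesIO()
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
    },
    buffer,
)

# Restore into freshly constructed objects to continue training.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location=device)
model2, optimizer2, scaler2 = make_training_objects()
model2.load_state_dict(checkpoint["model"])
optimizer2.load_state_dict(checkpoint["optimizer"])
scaler2.load_state_dict(checkpoint["scaler"])
```

Restoring the scaler keeps the current loss scale, so resumed training does not have to rediscover a good scale from the initial value.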

Thank you for your answer. Is there anything wrong with using apex?
Why do you recommend the native mixed-precision implementation?
I am working with high-dimensional 3D images, so I frequently run into out-of-memory (OOM) errors.
With apex everything was fine until a short while ago, but after some changes to my code I am getting OOM errors again and I don't know why!

My second question: can mixed precision and DistributedDataParallel be used at the same time, and would using both solve the OOM problem for larger batches? I tried DataParallel before and it didn't help.
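Native mixed precision can be combined with DistributedDataParallel. A minimal sketch, assuming one process per GPU (NCCL backend) and falling back to a single CPU process with the gloo backend so it also runs without GPUs; the `run` helper, the toy `nn.Linear` model and the file-based rendezvous are hypothetical choices for illustration. Keep in mind that DDP holds a full model replica on every GPU, so it does not by itself lower per-GPU memory for the same per-GPU batch; it lets the global batch grow with the number of GPUs while each GPU still processes, e.g., 2 volumes. Mixed precision is what mainly reduces the per-GPU activation memory.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int, init_file: str) -> float:
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(rank)
    backend = "nccl" if use_cuda else "gloo"
    dist.init_process_group(backend, init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    device = torch.device("cuda", rank) if use_cuda else torch.device("cpu")

    # One full model replica per process; DDP all-reduces the gradients.
    model = nn.Linear(16, 4).to(device)
    ddp_model = DDP(model, device_ids=[rank] if use_cuda else None)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    # Each process only sees its own per-GPU batch (here 2 samples).
    x = torch.randn(2, 16, device=device)
    y = torch.randn(2, 4, device=device)

    optimizer.zero_grad()
    # Mixed-precision forward pass (disabled on the CPU fallback).
    with torch.autocast(device_type=device.type, enabled=use_cuda):
        loss = nn.functional.mse_loss(ddp_model(x), y)
    # Gradients are all-reduced across processes during backward.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    dist.destroy_process_group()
    return loss.item()


if __name__ == "__main__":
    # Single-process demo; with N GPUs one would instead launch N processes,
    # e.g. torch.multiprocessing.spawn(run, args=(N, init_file), nprocs=N).
    with tempfile.TemporaryDirectory() as tmp:
        print(run(0, 1, os.path.join(tmp, "init")))
```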
At the moment, I am working with a batch size of 2, but with this batch size the loss does not behave well: it does not decrease, it actually increases. After googling a lot, I found that this is due to the batch size and the dataset I have. I tested the same setup in 2D and it worked fine, but in 3D I didn't get good results.

apex.amp was developed as the first implementation of mixed-precision training, before @mcarilli implemented native amp.
The main advantage of native amp is that you can use it by just installing the (nightly) binaries, without building apex as an extension (which is not always straightforward if it's your first time). In the long run, we'll move our focus to native amp and deprecate apex.amp.
Here is also a more detailed description.

Without seeing the code, it's hard to tell which of these changes suddenly caused the OOM issue.
