Speed up training when a large part of the forward pass doesn't require grad?

Hello,

I have an atypical training regime: a large part (roughly 80%) of the model's forward call doesn't require storing activations, so I wrap that part in a no_grad() context (there are also parts of the model that do need to store activations - those run outside the no_grad() context). I was wondering if there is a way to speed up these computations - since they are essentially inference, I should be able to use FP16 speedup tricks (my GPUs do have Tensor Cores).
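For context, here is a simplified sketch of what I mean (module names and sizes are placeholders, not my actual model):

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder submodules: `frozen_part` stands for the ~80% of the
        # forward call that never needs gradients or stored activations.
        self.frozen_part = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
        self.trainable_part = nn.Linear(1024, 10)

    def forward(self, x):
        # No activations are stored for this (large) part.
        with torch.no_grad():
            x = self.frozen_part(x)
        # Activations are stored only for this (small) part.
        return self.trainable_part(x)
```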

Should I use something like NVIDIA Apex (which, I assume, would switch to FP16 when executing inside no_grad()?),
or should I explicitly convert the model to FP16 (with model.half()) and convert it back to FP32 every time I want to do grad-enabled computations?
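To make the second option concrete, this is roughly what I have in mind, building on the sketch above (untested, intended for a GPU with Tensor Cores; names are placeholders):

```python
def forward_mixed(model, x):
    # Run the big no-grad part in FP16, then switch back to FP32
    # for the grad-enabled part.
    with torch.no_grad():
        model.frozen_part.half()            # cast weights to FP16
        h = model.frozen_part(x.half())     # inference-like FP16 forward
        model.frozen_part.float()           # cast back before grad-enabled work
    return model.trainable_part(h.float())  # FP32 forward, activations stored
```

My worry is that casting the weights back and forth on every step adds its own overhead, which is why I'm asking whether there's a better-supported way to do this.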

Thanks!