Leveraging half-precision training in PPO and Transformer-XL

Hi folks!

It seems that training in lower precision is an under-explored area in DRL:

  • There is an ICML paper that examines SAC in low precision,
  • and another one that deals with quantization-aware training and post-quantization experiments.
  • It would be great to know if you know of more papers =)

My actual problem centers around the massive memory consumption of my PPO + Transformer-XL (TrXL) implementation. Caching the hidden states output by Transformer-XL requires lots of memory, since the buffer has the shape (num_episodes, max_episode_steps, num_trxl_layers, trxl_dim). Because PPO truncates episodes, I need to keep these hidden states buffered even across epochs.
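Just to give a sense of scale, here is a minimal back-of-the-envelope sketch of the cache size in fp32 vs. half precision. The buffer dimensions below are made-up placeholder values, not my actual config:

```python
import math

# Hypothetical buffer dimensions, just to illustrate the memory math
num_episodes, max_episode_steps = 512, 512
num_trxl_layers, trxl_dim = 3, 384

shape = (num_episodes, max_episode_steps, num_trxl_layers, trxl_dim)
num_elements = math.prod(shape)

bytes_fp32 = num_elements * 4  # float32: 4 bytes per element
bytes_fp16 = num_elements * 2  # float16 / bfloat16: 2 bytes per element

print(f"fp32 cache:      {bytes_fp32 / 1e9:.2f} GB")
print(f"fp16/bf16 cache: {bytes_fp16 / 1e9:.2f} GB")
```

So storing the cache in half precision would cut this buffer in half, which is exactly what I'm after.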

I’d like to mitigate this memory overhead by keeping these hidden states in half precision (either fp16 or bf16). What’s the most convenient way to do this? Should I cast the tensors from fp32 to fp16 manually? Or should I use AMP and only keep the hidden-state tensors in lower precision? Will AMP also affect calculations that don’t involve the lower-precision data, e.g. will it automatically run convolutional layers in lower precision? So far it’s a bit ambiguous to me what AMP casts automatically and what it doesn’t.
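To make the two options I'm weighing concrete, here is a minimal PyTorch sketch. The function names are made up, the Linear layer just stands in for a TrXL block, and the shapes are arbitrary:

```python
import torch

# Option A: manually store the cached hidden states in bf16 and cast them
# back to fp32 right before they are fed into the (fp32) TrXL layers.
def cache_hidden_states(hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: fp32 output of the TrXL layers
    return hidden_states.detach().to(torch.bfloat16)  # halves the buffer size

def fetch_hidden_states(cached: torch.Tensor) -> torch.Tensor:
    return cached.float()  # back to fp32 for the forward pass

# Option B: run the forward pass under autocast. Inside the context, autocast
# decides per op: matmuls/convolutions/linear layers run in the autocast dtype
# and produce outputs in that dtype, while numerically sensitive ops
# (e.g. softmax, layer norm) stay in fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(384, 384).to(device)  # stand-in for a TrXL block
obs = torch.randn(8, 384, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    out = model(obs)

print(out.dtype)  # torch.bfloat16 -> could be cached without an explicit cast
```

With Option B the cached states would come out in bf16 anyway, but autocast also changes the precision of the forward pass itself, which is part of what I'm unsure about.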

Before trying AMP or manually casting hidden states, I’d love to hear some thoughts on this. Thanks in advance =)