We currently have a prototype API, `_set_static_graph`, which can be applied to DDP if your training graph is static across all iterations (i.e., there is no conditional execution in the model). Documentation: pytorch/distributed.py at master · pytorch/pytorch · GitHub.
With static-graph training, DDP records how many times each parameter expects to receive a gradient during one iteration and memorizes that count. This resolves the issue with activation checkpointing (where recomputation would otherwise cause a parameter's gradient hook to fire more than once) and should make it work.
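For concreteness, here is a minimal sketch of how the two pieces fit together. The `Net` module, tensor shapes, and gloo/CPU backend are placeholders for illustration, not part of the original post, and the process-group setup will depend on how you launch the job:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(128, 128)
        self.layer2 = nn.Linear(128, 128)

    def forward(self, x):
        x = self.layer1(x)
        # Activation checkpointing: layer2's activations are recomputed
        # during the backward pass instead of being stored.
        return checkpoint(self.layer2, x)

def main():
    # Process-group setup depends on how the job is launched
    # (e.g. torchrun sets RANK / WORLD_SIZE / MASTER_ADDR for you).
    dist.init_process_group(backend="gloo")

    model = DDP(Net())
    # Prototype API: declare the graph static so DDP can record, on the
    # first iteration, how many times each parameter will receive a
    # gradient and reuse that count in later iterations.
    model._set_static_graph()

    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(3):
        opt.zero_grad()
        loss = model(torch.randn(8, 128)).sum()
        loss.backward()
        opt.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Note that `_set_static_graph` must be called before the first forward/backward pass, since DDP uses the first iteration to record the expected gradient counts.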