Transformers in O(log(n))

A recent paper by Google shows it's possible to greatly reduce the memory footprint of transformers.
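
For context, here is a rough sketch of what I understand the idea to be, assuming the paper is about memory-efficient attention: keys and values are processed in chunks with a running softmax normalizer, so the full n × n score matrix is never materialized. The function name `chunked_attention` and the chunk sizes below are just placeholders of mine, not anything from the paper or from PyTorch itself:

```python
# Minimal sketch of chunked, memory-efficient attention (not the paper's code).
# Assumes single-head attention on (batch, seq, dim) tensors.
import torch

def chunked_attention(q, k, v, q_chunk=128, kv_chunk=128):
    b, n, d = q.shape
    scale = d ** -0.5
    out = torch.empty_like(q)
    for i in range(0, n, q_chunk):
        qi = q[:, i:i + q_chunk] * scale                    # (b, cq, d)
        # running statistics for the online softmax over key chunks
        acc = torch.zeros_like(qi)                          # weighted sum of values
        row_sum = qi.new_zeros(b, qi.shape[1], 1)           # softmax denominator
        row_max = qi.new_full((b, qi.shape[1], 1), float("-inf"))
        for j in range(0, n, kv_chunk):
            kj = k[:, j:j + kv_chunk]
            vj = v[:, j:j + kv_chunk]
            s = qi @ kj.transpose(-2, -1)                   # (b, cq, ckv) scores
            new_max = torch.maximum(row_max, s.amax(dim=-1, keepdim=True))
            # rescale previous accumulators to the new max, then add this chunk
            correction = torch.exp(row_max - new_max)
            p = torch.exp(s - new_max)
            acc = acc * correction + p @ vj
            row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
            row_max = new_max
        out[:, i:i + qi.shape[1]] = acc / row_sum
    return out

# Quick check against the naive implementation
if __name__ == "__main__":
    q, k, v = (torch.randn(2, 512, 64) for _ in range(3))
    ref = torch.softmax((q @ k.transpose(-2, -1)) / 64 ** 0.5, dim=-1) @ v
    print(torch.allclose(chunked_attention(q, k, v), ref, atol=1e-5))
```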

Is something like this likely to be implemented in PyTorch, and if so, when?