You could always try instance norm (unless you have very few elements per channel to compute statistics over), which takes away the difference between training and evaluation.
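For example (assuming 2d inputs and a placeholder channel count), the swap is roughly:

```python
import torch.nn as nn

# note: InstanceNorm2d defaults to affine=False and track_running_stats=False,
# unlike BatchNorm2d, so set them explicitly if you want the same behaviour
norm = nn.InstanceNorm2d(64, affine=True)  # instead of nn.BatchNorm2d(64)
```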
If you wanted to experiment with batch renorm, starting from current master makes that easier: in its native CUDA kernels, we split gathering the statistics from applying the pointwise transformation. So you could reuse the first part and just adapt the pointwise transformation to use the amended formula. You can keep the vanilla backward function (the “stop grad” means the amendments are ignored in the backward).
I would recommend doing this as a custom op or extension module + torch.autograd.Function.
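Purely to illustrate the amended formula (not the split native-kernel route above), a minimal Python-level sketch could look like the following; the r_max / d_max clipping bounds and the momentum are just placeholders, and the .detach() calls play the role of the stop-grad:

```python
import torch
import torch.nn as nn

class BatchRenorm2d(nn.Module):
    """Illustrative batch renorm over NCHW inputs (pure Python, no native kernels)."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, r_max=3.0, d_max=5.0):
        super().__init__()
        self.eps, self.momentum = eps, momentum
        self.r_max, self.d_max = r_max, d_max          # clipping bounds (assumed values)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_std", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # gather per-channel batch statistics over N, H, W
            mean = x.mean(dim=(0, 2, 3))
            std = (x.var(dim=(0, 2, 3), unbiased=False) + self.eps).sqrt()
            # the amendments r and d are treated as constants ("stop grad"),
            # hence the .detach(); autograd then provides the backward
            r = (std.detach() / self.running_std).clamp(1 / self.r_max, self.r_max)
            d = ((mean.detach() - self.running_mean) / self.running_std).clamp(-self.d_max, self.d_max)
            x_hat = (x - mean[None, :, None, None]) / std[None, :, None, None]
            x_hat = x_hat * r[None, :, None, None] + d[None, :, None, None]
            with torch.no_grad():  # update moving statistics with the batch statistics
                self.running_mean += self.momentum * (mean - self.running_mean)
                self.running_std += self.momentum * (std - self.running_std)
        else:
            # at evaluation time, fall back to the moving statistics
            x_hat = (x - self.running_mean[None, :, None, None]) / self.running_std[None, :, None, None]
        return self.weight[None, :, None, None] * x_hat + self.bias[None, :, None, None]
```

Wrapping the same pointwise logic in a torch.autograd.Function on top of the native statistics kernels would then be the faster route described above.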
Best regards
Thomas