Computing Jacobian Vector Product (jvp) using nn.DataParallel

As I understand it, the problem is that sync_batch_norm (which inherits from autograd.Function) is missing the jvp method. This can be seen here: pytorch/_functions.py at master · pytorch/pytorch · GitHub. Note that SyncBatchNorm implements forward and backward, but not jvp.

Am I expected to implement the jvp call for SyncBatchNorm myself? It seems too complicated…
Am I missing something?
Are there any workarounds, e.g. using regular batch norm or the double grad trick (see "Forward Mode AD vs double grad trick for jacobian vector product")?
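
For reference, here is a minimal sketch of what I mean by the double grad trick. It only uses reverse-mode calls (torch.autograd.grad twice), so it does not need a jvp method on any autograd.Function. The helper name jvp_double_backward and the toy Linear + Tanh model are my own placeholders, not anything from PyTorch itself. I have not verified that it works through SyncBatchNorm: that would require its backward to be differentiable with respect to the incoming gradient.

```python
import torch
import torch.nn as nn


def jvp_double_backward(fn, x, v):
    """Compute J(x) @ v using two reverse-mode passes (hypothetical helper)."""
    x = x.detach().requires_grad_(True)
    y = fn(x)
    # Dummy cotangent u; its value is irrelevant because g is linear in u.
    u = torch.zeros_like(y, requires_grad=True)
    # First pass: g = J^T u, keeping the graph so that g depends on u.
    (g,) = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)
    # Second pass: differentiating <g, v> with respect to u gives J v.
    (jv,) = torch.autograd.grad(g, u, grad_outputs=v)
    return y.detach(), jv


torch.manual_seed(0)
# Toy stand-in for the real network (assumption, no batch norm here).
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh()).double()
x = torch.randn(5, 3, dtype=torch.double)  # input batch
v = torch.randn(5, 3, dtype=torch.double)  # tangent, same shape as x

y, jv = jvp_double_backward(model, x, v)
print(jv.shape)
```

The result has the same shape as the model output. The cost is roughly two backward passes instead of one forward-mode pass, which is the trade-off discussed in the thread linked above.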