As I understand it, the problem is that sync_batch_norm (which inherits from autograd.Function) is the one that is missing the jvp method (this can be seen here: pytorch/_functions.py at master · pytorch/pytorch · GitHub; note that SyncBatchNorm implements forward and backward, but not jvp).
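For context, this is roughly what I am trying to do (a minimal sketch, assuming a DDP/GPU setup with an initialized process group so that the sync_batch_norm autograd.Function is actually hit; the model and variable names are just placeholders):

import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

# Placeholder model; in my real code this runs under DDP on GPU,
# so SyncBatchNorm dispatches to the sync_batch_norm autograd.Function.
model = nn.Sequential(nn.Linear(4, 8), nn.SyncBatchNorm(8), nn.ReLU())

x = torch.randn(16, 4)
v = torch.randn_like(x)  # tangent (the "vector" in the jvp)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, v)
    dual_y = model(dual_x)  # this is where the missing jvp shows up as a RuntimeError
    y, jvp = fwAD.unpack_dual(dual_y)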
Am I expected to implement the jvp method for SyncBatchNorm myself? That seems too complicated…
Am I missing something?
Are there any workarounds? (E.g. using normal batch norm, or the double-grad trick discussed in Forward Mode AD vs double grad trick for jacobian vector product?)
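For the second workaround, this is the kind of double-grad trick I mean (a rough sketch, shown with plain BatchNorm1d just to illustrate the mechanics; the names are placeholders, not taken from the linked thread):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())

x = torch.randn(16, 4, requires_grad=True)
v = torch.randn_like(x)  # tangent for the jvp

y = model(x)

# Double-grad trick: get a jvp using only reverse mode.
# Step 1: vjp with a dummy cotangent u that itself requires grad -> g = J^T u
u = torch.zeros_like(y, requires_grad=True)
(g,) = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)
# Step 2: g is linear in u, so d(g . v)/du = J v
(jvp,) = torch.autograd.grad(g, u, grad_outputs=v)

print(jvp.shape)  # same shape as y

My understanding is that this needs a forward pass plus two reverse-mode passes per jvp, which I assume is more expensive than a single forward-mode pass, hence my preference for forward-mode AD if SyncBatchNorm supported it.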