Computing Jacobian Vector Product (jvp) using nn.DataParallel

Hi and thanks for your answers!

I have tried to implement the double-backward trick for computing JVPs as described in:
https://j-towns.github.io/2017/06/12/A-new-trick.html
and here:
https://discuss.pytorch.org/t/forward-mode-ad-vs-double-grad-trick-for-jacobian-vector-product/159037
which is also how PyTorch's own torch.autograd.functional.jvp is implemented:
https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html
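
To double-check the trick itself, here is a minimal, self-contained sketch on a plain function (no model, no BatchNorm), compared against torch.autograd.functional.jvp:

import torch

# f(x) = (x ** 2).sum(dim=1); compute J @ v without forward-mode AD
x = torch.randn(4, 3, requires_grad=True)   # primal input
v = torch.randn(4, 3)                       # tangent, same shape as x

y = (x ** 2).sum(dim=1)                     # f(x), shape (4,)

# first backward: u is a dummy cotangent we differentiate *through*;
# its value does not matter because g is linear in u
u = torch.zeros_like(y, requires_grad=True)
g = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)[0]   # g = J^T u

# second backward w.r.t. u turns J^T into J and yields the JVP
jvp = torch.autograd.grad(g, u, grad_outputs=v)[0]                    # J @ v, shape (4,)

# sanity check against the built-in implementation
_, jvp_ref = torch.autograd.functional.jvp(lambda t: (t ** 2).sum(dim=1), (x,), (v,))
print(torch.allclose(jvp, jvp_ref))   # True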

However, it seems that SyncBatchNorm still causes problems.
I get the following error:

File "/home/mitchell/DivideMix/meta_utils.py", line 261, in jacobian_vector_product
    jvp = torch.autograd.grad(g, v, vector)[0]
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/autograd/__init__.py", line 276, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: derivative for aten::batch_norm_backward_elemt is not implemented

Here is my implementation:

import torch
import torch.nn.functional as F


def jacobian_vector_product(model, inputs, vector, method='double-back-trick'):
    if method == 'forward':
        '''
        jvp products using forward mode AD as demonstrated in:
        https://pytorch.org/tutorials/intermediate/forward_ad_usage.html
        '''
        params = {name: p for name, p in model.named_parameters()}
        tangents = {name: vec_pt for name, vec_pt in zip(params.keys(), vector)}

        with torch.autograd.forward_ad.dual_level():
            for name, p in params.items():
                # del_attr / set_attr recursively replace the (possibly nested)
                # parameter with a dual tensor (primal = p, tangent = tangents[name])
                del_attr(model, name.split("."))
                set_attr(model, name.split("."), torch.autograd.forward_ad.make_dual(p, tangents[name]))

            out = model(inputs)
            lsm = F.log_softmax(out, dim=1)
            jvp = torch.autograd.forward_ad.unpack_dual(lsm).tangent
            return jvp
    elif method == 'double-back-trick':
        '''
        jvp products using double backward as demonstrated in:
        https://j-towns.github.io/2017/06/12/A-new-trick.html
        https://discuss.pytorch.org/t/forward-mode-ad-vs-double-grad-trick-for-jacobian-vector-product/159037
        '''
        out = model(inputs)
        lsm = F.log_softmax(out, dim=1)
        # dummy cotangent to differentiate through; its value is irrelevant
        # because g below is linear in v
        v = torch.zeros_like(lsm, requires_grad=True)
        # first backward: g = J^T v, where J is the Jacobian of lsm w.r.t. the parameters
        g = torch.autograd.grad(lsm, model.parameters(), v, create_graph=True)
        # second backward w.r.t. the dummy turns J^T into J, giving J @ vector
        jvp = torch.autograd.grad(g, v, vector)[0]
        return jvp
    else:
        raise NotImplementedError
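
For completeness, del_attr and set_attr are the usual recursive attribute helpers (roughly like this):

def del_attr(obj, names):
    # delete a possibly nested attribute, e.g. names = ["layer1", "0", "weight"]
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    # set a possibly nested attribute
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

And this is roughly how I call it (a simplified sketch; the nn.Sequential here is just a stand-in, the real model uses SyncBatchNorm):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))
inputs = torch.randn(8, 10)

# one tangent tensor per parameter, in model.parameters() order
vector = tuple(torch.randn_like(p) for p in model.parameters())

jvp = jacobian_vector_product(model, inputs, vector, method='double-back-trick')
print(jvp.shape)   # torch.Size([8, 5]), same shape as log_softmax(model(inputs))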

Am I making a mistake, or is this another issue with SyncBatchNorm?