Computing Jacobian Vector Product (jvp) using nn.DataParallel

Hello :slight_smile:
I have an application that requires computing the derivative of the model’s logits (more precisely, their log-softmax) with respect to the model’s parameters, multiplied by a constant vector.
In theory, this Jacobian-vector product can be computed exactly in a single forward pass, as demonstrated in the following PyTorch tutorial: Forward-mode Automatic Differentiation (Beta) — PyTorch Tutorials 1.13.1+cu117 documentation.
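The core mechanism from that tutorial, in a minimal toy example (just to show the make_dual / unpack_dual flow, independent of my model code below):

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)   # primal input
v = torch.randn(3)   # the constant vector, used as the tangent

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, v)
    y = (dual_x ** 2).sum()
    # the tangent of the output is J @ v, obtained in the same forward pass
    jvp = fwAD.unpack_dual(y).tangent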

Following this tutorial, I managed to write a version of this code for my use case:

def jacobian_vector_product(model, inputs, vector):
    params = {name: p for name, p in model.named_parameters()}
    tangents = {name: vec_pt for name, vec_pt in zip(params.keys(), vector)}

    with torch.autograd.forward_ad.dual_level():
        for name, p in params.items():
            del_attr(model, name.split("."))
            set_attr(model, name.split("."), torch.autograd.forward_ad.make_dual(p, tangents[name]))

        out = model(inputs)
        lsm = F.log_softmax(out, dim=1)
        jvp = torch.autograd.forward_ad.unpack_dual(lsm).tangent
        return jvp

(Note that the tutorial’s setattr and delattr calls were replaced by nested versions, del_attr and set_attr, which traverse dotted parameter names.)
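For reference, the nested helpers look roughly like this (a minimal sketch of what del_attr / set_attr do):

def del_attr(obj, names):
    # walk down the dotted attribute path and delete the leaf attribute
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    # walk down the dotted attribute path and set the leaf attribute
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)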

Indeed, the above code works perfectly when I don’t use any kind of data parallelism.

However, when I wrap the model with nn.DataParallel(), the forward pass in this function raises an error indicating that the model and the inputs are not on the same device.
After some quick debugging I found that the batch is indeed split into multiple chunks and the inputs are scattered to the correct devices, but all the forward passes occur on cuda:0 instead of being distributed across the devices along with their inputs.

I think the issue has something to do with the attribute-editing loop: when I comment it out, the forward pass works properly (although, of course, the jvp is then None).
I have also tried wrapping the model with DataParallel after editing the attributes, but I get the same error.

I would appreciate any assistance on this,
Thanks!

I guess the newly added features such as forward-mode AD might not be supported and tested in e.g. nn.DataParallel, as it’s planned to be deprecated. We generally recommend using DistributedDataParallel instead, which should also perform better than DataParallel.
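A minimal sketch of the DDP setup (assuming a torchrun launch so LOCAL_RANK is set; build_model() is just a placeholder for your model construction):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")      # one process per GPU
local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)       # build_model() is a placeholder
model = DDP(model, device_ids=[local_rank])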

Hi, and thanks for the very quick reply!

After changing DataParallel to DDP I get the following error in my function:

NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD.

Tracing the error, it seems that the Synchronized Batch Normalization (SyncBN) layer in my model is causing the problem.

File "/home/mitchell/DivideMix/meta_utils.py", line 241, in jacobian_vector_product
    out = model(inputs)
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torchvision/models/resnet.py", line 285, in forward
    return self._forward_impl(x)
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torchvision/models/resnet.py", line 269, in _forward_impl
    x = self.bn1(x)
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 748, in forward
    return sync_batch_norm.apply(
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/autograd/function.py", line 257, in apply_jvp
    return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/autograd/function.py", line 391, in jvp
    raise NotImplementedError("You must implement the jvp function for custom "
NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD.

What can I do to resolve it?

Thanks in advance!

It seems the static jvp method is missing in an autograd.Function and you would need to implement it as described here.
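As a toy illustration of what implementing jvp means (unrelated to the actual SyncBatchNorm math), a custom autograd.Function that supports forward-mode AD looks roughly like this:

import torch

class Scale(torch.autograd.Function):
    # toy function y = 2 * x with both reverse- and forward-mode support

    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        # reverse mode: vector-Jacobian product
        return 2 * grad_output

    @staticmethod
    def jvp(ctx, x_tangent):
        # forward mode: tangent of the output given the tangent of the input
        return 2 * x_tangent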

As I understand it, the problem is that sync_batch_norm (which inherits from autograd.Function) is the one missing the jvp method (this can be seen here: pytorch/_functions.py at master · pytorch/pytorch · GitHub; note that SyncBatchNorm implements forward and backward, but not jvp).

Am I expected to implement the jvp call for SyncBatchNorm myself? It seems too complicated…
Am I missing something?
Are there any workarounds? (E.g. using a regular batch norm, or the double-grad trick? Forward Mode AD vs double grad trick for jacobian vector product)

Yes, you would either need to implement jvp for SyncBatchNorm or fall back to the standard nn.BatchNormXd layers (assuming they provide this method), which would then not synchronize the running stats but would hopefully work otherwise.
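A possible sketch of the fallback, replacing every SyncBatchNorm with a plain BatchNorm2d that reuses the same parameters and running statistics (assumes 2d features; revert_sync_batchnorm is just a helper name for this sketch, not a PyTorch API):

import torch.nn as nn

def revert_sync_batchnorm(module):
    # recursively swap SyncBatchNorm layers for BatchNorm2d, keeping weights and stats
    module_output = module
    if isinstance(module, nn.SyncBatchNorm):
        module_output = nn.BatchNorm2d(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        if module.affine:
            module_output.weight = module.weight
            module_output.bias = module.bias
        if module.track_running_stats:
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    return module_output

# e.g. model = revert_sync_batchnorm(model) before computing the jvp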

Hi and thanks for your answers!

I have tried to implement the double-back trick for computing jvps as described in:
https://j-towns.github.io/2017/06/12/A-new-trick.html
and in this forum thread:
https://discuss.pytorch.org/t/forward-mode-ad-vs-double-grad-trick-for-jacobian-vector-product/159037
which is also the approach used by PyTorch’s own torch.autograd.functional.jvp:
https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html
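For reference, the identity behind the trick, as I understand it: introduce a dummy variable u with the same shape as the output and compute g(u) = J^T u via an ordinary reverse-mode pass with create_graph=True. Since g is linear in u, differentiating g(u) · w with respect to u yields J w, i.e. the jvp with the vector w, and the actual value of u never matters (hence the zeros in the implementation below).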

However, it seems that SyncBN continues to cause problems; I get the following error:

File "/home/mitchell/DivideMix/meta_utils.py", line 261, in jacobian_vector_product
    jvp = torch.autograd.grad(g, v, vector)[0]
  File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/autograd/__init__.py", line 276, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: derivative for aten::batch_norm_backward_elemt is not implemented

Here is my implementation:

def jacobian_vector_product(model, inputs, vector, method='double-back-trick'):
    
    if method == 'forward':
        '''
        jvp products using forward mode AD as demonstrated in:
        https://pytorch.org/tutorials/intermediate/forward_ad_usage.html
        '''
        params = {name: p for name, p in model.named_parameters()}
        tangents = {name: vec_pt for name, vec_pt in zip(params.keys(), vector)}

        with torch.autograd.forward_ad.dual_level():
            for name, p in params.items():
                del_attr(model, name.split("."))
                set_attr(model, name.split("."), torch.autograd.forward_ad.make_dual(p, tangents[name]))

            out = model(inputs)
            lsm = F.log_softmax(out, dim=1)
            jvp = torch.autograd.forward_ad.unpack_dual(lsm).tangent
            return jvp
    elif method == 'double-back-trick':
        '''
        jvp products using double backward as demonstrated in:
        https://j-towns.github.io/2017/06/12/A-new-trick.html
        https://discuss.pytorch.org/t/forward-mode-ad-vs-double-grad-trick-for-jacobian-vector-product/159037
        '''
        out = model(inputs)
        lsm = F.log_softmax(out, dim=1)
        # dummy variable for the double-backward trick (its value never matters)
        v = torch.zeros_like(lsm, requires_grad=True)
        # first backward pass: g = J^T v, keeping the graph so it can be differentiated again
        g = torch.autograd.grad(lsm, model.parameters(), v, create_graph=True)
        # second backward pass w.r.t. the dummy variable, weighted by `vector`, yields J @ vector
        jvp = torch.autograd.grad(g, v, vector)[0]
        return jvp
    else:
        raise NotImplementedError

Am I making a mistake, or is this another issue with SyncBN?