Interaction of autograd with scatter_add_ and is_grads_batched

I am trying to understand a confusing interaction between is_grads_batched=True in torch.autograd.grad and scatter_add_ operations. My loss function requires the derivatives of the outputs of the (multi-output) model, so I take the gradients in the forward pass:

batch_size, n_outputs = outputs.shape
eye_mat = torch.eye(n_outputs)
# one grad_outputs slice per model output, broadcast over the batch
grad_output = eye_mat.unsqueeze(1).expand(n_outputs, batch_size, n_outputs).to(outputs.device)

gradient = torch.autograd.grad(
    outputs=outputs,
    inputs=inputs,
    grad_outputs=grad_output,
    retain_graph=training,  # True for training, False for prediction
    create_graph=training,
    allow_unused=True,
    is_grads_batched=True,
)[0]
if gradient is None:
    raise RuntimeError("gradient is None")
gradient = gradient.transpose(0, 1).transpose(1, 2)
return -1 * gradient
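
For reference, my understanding is that is_grads_batched treats the leading dimension of grad_outputs as a batch of vjp vectors, so the call above computes one vector-Jacobian product per output channel in a single backward pass. A tiny standalone example of the same mechanism:

x = torch.randn(4, requires_grad=True)
y = x ** 2
# Each row of the identity selects one element of y, so the batched call
# returns the full Jacobian dy/dx in one shot (a diagonal matrix here).
jac = torch.autograd.grad(y, x, grad_outputs=torch.eye(4), is_grads_batched=True)[0]
assert torch.allclose(jac, torch.diag(2 * x))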

I am experimenting with is_grads_batched=True to get a speed boost when taking these derivatives for each output of the multi-output model. I get the following error when I set torch.autograd.set_detect_anomaly(True) (without anomaly detection I get no errors, but the model clearly trains incorrectly on the derivatives):

  File "/project_root/scripts/run_train.py", line 609, in <module>
    main()
  File "/project_root/scripts/run_train.py", line 537, in main
    tools.train(
  File "/project_root/src/tools/train.py", line 99, in train
    _, opt_metrics = take_step(
  File "/project_root/src/tools/train.py", line 258, in take_step
    output = model(
  File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/project_root/src/modules/models.py", line 588, in forward
    inter_e = scatter_sum(
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/project_root/scripts/run_train.py", line 609, in <module>
    main()
  File "/project_root/scripts/run_train.py", line 537, in main
    tools.train(
  File "/project_root/src/tools/train.py", line 99, in train
    _, opt_metrics = take_step(
  File "/project_root/src/tools/train.py", line 258, in take_step
    output = model(
  File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/project_root/src/modules/models.py", line 603, in forward
    forces, virials, stress = get_outputs(
  File "/project_root/src/modules/utils.py", line 321, in get_outputs
    compute_forces(
  File "/project_root/src/modules/utils.py", line 47, in compute_forces
    gradient = torch.autograd.grad(
  File "/home/usrname/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 390, in grad
    result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
  File "/home/usrname/.local/lib/python3.9/site-packages/torch/_vmap_internals.py", line 223, in wrapped
    batched_outputs = func(*batched_inputs)
  File "/home/usrname/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 380, in vjp
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Batching rule not implemented for aten::item. We could not generate a fallback.

The function scatter_sum is given by:

from typing import Optional

import torch


def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    # expand `src` (typically the 1-D index tensor) to the shape of `other`
    # so it can be passed to scatter_add_
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
    return src


@torch.jit.script
def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
    reduce: str = "sum",
) -> torch.Tensor:
    assert reduce == "sum"
    index = _broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            # converting the tensor max to a Python int goes through .item()
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)
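
For context, a quick illustration of what scatter_sum does (values chosen arbitrarily): rows of src are summed into buckets given by index:

src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
index = torch.tensor([0, 1, 0])
scatter_sum(src, index, dim=0)
# tensor([[6., 8.],
#         [3., 4.]])  # rows 0 and 2 go to bucket 0, row 1 to bucket 1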

I have tried replacing scatter_add_ with the out-of-place scatter_add and this makes no difference.
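
That is, the only change was in the final calls, roughly:

    return out.scatter_add(dim, index, src)  # out-of-place variant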

This issue does not occur when I use a version of the first code snippet where is_grads_batched=False and the batched grad_outputs is replaced by a loop over the outputs (sketched below); in that case the model trains correctly on the derivatives. Can anyone explain what is going on here?
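
For reference, the loop variant looks roughly like this (paraphrased, not the exact code; names follow the first snippet):

gradients = []
for i in range(n_outputs):
    grad_output = torch.zeros_like(outputs)
    grad_output[:, i] = 1.0  # pick out the i-th output channel
    gradients.append(
        torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=grad_output,
            retain_graph=True,  # the graph is reused across the loop
            create_graph=training,
            allow_unused=True,
        )[0]
    )
gradient = torch.stack(gradients, dim=-1)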

is_grads_batched uses an older version of vmap with worse coverage. You might want to try torch.func (see the torch.func Whirlwind Tour in the PyTorch 2.1 documentation) instead.
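
Something along these lines, as a rough untested sketch (model and inputs stand in for whatever produces the (batch_size, n_outputs) tensor in the original code):

import torch
from torch.func import vjp, vmap

outputs, vjp_fn = vjp(lambda x: model(x), inputs)
batch_size, n_outputs = outputs.shape
cotangents = (
    torch.eye(n_outputs, device=outputs.device)
    .unsqueeze(1)
    .expand(n_outputs, batch_size, n_outputs)
)
# vmap over the leading dimension of the cotangents: one vjp per output channel,
# playing the same role as grad_outputs with is_grads_batched=True
gradient = vmap(vjp_fn)(cotangents)[0]  # shape (n_outputs, *inputs.shape)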