I am trying to understand a confusing interaction between `is_grads_batched=True` in `torch.autograd.grad` and `scatter_` operations. My loss function requires the derivatives of the output of the (multi-output) model, so I take these gradients in the forward pass:
```python
batch_size, n_outputs = output.shape
eye_mat = torch.eye(n_outputs)
# One vjp "seed" per output dimension: shape (n_outputs, batch_size, n_outputs)
grad_output = eye_mat.unsqueeze(1).expand(n_outputs, batch_size, n_outputs).to(energy.device)
gradient = torch.autograd.grad(
    outputs=outputs,
    inputs=inputs,
    grad_outputs=grad_output,
    retain_graph=training,  # True for training, False for prediction
    create_graph=training,
    allow_unused=True,
    is_grads_batched=True,
)[0]
if gradient is None:
    raise RuntimeError("gradient is None")
# Reorder axes so the per-output dimension comes last
gradient = gradient.transpose(0, 1).transpose(1, 2)
return -1 * gradient
```
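To make the intended semantics concrete, here is a minimal, self-contained sketch of the same pattern (the toy quadratic model and the tensor names here are purely illustrative and not from my codebase). As I understand the documentation, `is_grads_batched=True` dispatches the batched vjps through the prototype vmap machinery, which is what shows up as `_vmap_internals` in the traceback below.

```python
import torch

# Toy multi-output model y = f(x); names and shapes are illustrative only
x = torch.randn(8, 3, requires_grad=True)                 # (batch_size, n_inputs)
y = (x ** 2).sum(dim=1, keepdim=True) * torch.ones(1, 4)  # (batch_size, n_outputs)

n_outputs = y.shape[1]
# One identity "row" per output, expanded over the batch: (n_outputs, batch_size, n_outputs)
grad_outputs = torch.eye(n_outputs).unsqueeze(1).expand(n_outputs, y.shape[0], n_outputs)

# A single vectorised call replaces n_outputs separate autograd.grad calls
jac = torch.autograd.grad(
    outputs=y,
    inputs=x,
    grad_outputs=grad_outputs,
    is_grads_batched=True,
)[0]
print(jac.shape)  # torch.Size([4, 8, 3]) -> (n_outputs, batch_size, n_inputs)
```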
I am experimenting with `is_grads_batched=True` to get a speed boost when taking these derivatives for each output of the multi-output model. I get the following error when I set `torch.autograd.set_detect_anomaly(True)` (when anomaly detection is off I get no errors, but the model clearly trains incorrectly on the derivatives):
File "/project_root/scripts/run_train.py", line 609, in <module>
main()
File "/project_root/scripts/run_train.py", line 537, in main
tools.train(
File "/project_root/src/tools/train.py", line 99, in train
_, opt_metrics = take_step(
File "/project_root/src/tools/train.py", line 258, in take_step
output = model(
File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/project_root/src/modules/models.py", line 588, in forward
inter_e = scatter_sum(
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
File "/project_root/scripts/run_train.py", line 609, in <module>
main()
File "/project_root/scripts/run_train.py", line 537, in main
tools.train(
File "/project_root/src/tools/train.py", line 99, in train
_, opt_metrics = take_step(
File "/project_root/src/tools/train.py", line 258, in take_step
output = model(
File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/usrname/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/project_root/src/modules/models.py", line 603, in forward
forces, virials, stress = get_outputs(
File "/project_root/src/modules/utils.py", line 321, in get_outputs
compute_forces(
File "/project_root/src/modules/utils.py", line 47, in compute_forces
gradient = torch.autograd.grad(
File "/home/usrname/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 390, in grad
result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
File "/home/usrname/.local/lib/python3.9/site-packages/torch/_vmap_internals.py", line 223, in wrapped
batched_outputs = func(*batched_inputs)
File "/home/usrname/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 380, in vjp
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Batching rule not implemented for aten::item. We could not generate a fallback.
The function `scatter_sum` is given by:
```python
from typing import Optional

import torch


def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
    return src


@torch.jit.script
def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
    reduce: str = "sum",
) -> torch.Tensor:
    assert reduce == "sum"
    index = _broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)
```
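For context, this is the standard torch_scatter-style scatter-sum. A tiny example of what it computes (values chosen purely for illustration):

```python
import torch

src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 1, 0, 2])

# Entries of src that share an index are summed together:
# out[i] = sum of src[j] over all j with index[j] == i
out = scatter_sum(src, index, dim=0)
print(out)  # tensor([4., 2., 4.])
```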
I have tried replacing `scatter_add_` with the out-of-place `scatter_add`, and this makes no difference. The issue does not occur when I use a version of the first code snippet with `is_grads_batched=False` (where the batched call is replaced by a loop over the outputs); in that case the model trains correctly on the derivatives (a sketch of that loop version is below). Can anyone explain what is going on here?