RuntimeError: Function 'PowBackward0' returned nan values in its 0th output. when using scatter_reduce

Hi,

I have implemented a neural network (TAPAS by Google AI) that uses some scatter operations. Until now I relied on the torch_scatter library, because PyTorch did not natively support the scatter operations I needed.

Now that PyTorch 1.12 has been released, and it includes scatter_reduce, I’d like to update my code to remove the torch_scatter dependency.

I replaced the following code:

from torch_scatter import scatter

segment_means = scatter(
    src=flat_values,
    index=flat_index.indices.long(),
    dim=0,
    dim_size=int(flat_index.num_segments),
    reduce=segment_reduce_fn,
)

with the following:

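import torch

# One output slot per segment, allocated on the same device as the values.
out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype, device=flat_values.device)
# reduce must be one of "sum", "prod", "mean", "amax", "amin".
segment_means = out.scatter_reduce(
    dim=0, index=flat_index.indices.long(), src=flat_values,
    reduce=segment_reduce_fn, include_self=False,
)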

This is giving me equivalent results for the forward pass.
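For reference, here is a minimal, self-contained sketch of the replacement on toy data. The tensor values, the `indices` tensor and `num_segments` are made up for illustration and are not taken from my model. It assumes PyTorch 1.12 or later and `reduce="mean"`.

```python
import torch

# Toy data, made up for illustration (not taken from the TAPAS model).
flat_values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
indices = torch.tensor([0, 0, 1, 2, 2])  # segment id of each value
num_segments = 4  # segment 3 receives no values

# Destination buffer with one slot per segment.
out = torch.zeros(num_segments, dtype=flat_values.dtype)

# include_self=False: the zeros in `out` do not take part in the mean.
segment_means = out.scatter_reduce(
    dim=0, index=indices, src=flat_values, reduce="mean", include_self=False
)
print(segment_means)
```

Segments 0 to 2 end up with the means 1.5, 3.0 and 4.5. Segment 3, which receives no values, keeps the 0 it had in `out`.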

However, when running a backward pass (loss.backward()), I’m getting the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [416]], which is output 0 of ScatterReduceBackward0, is at version 3; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

With torch.autograd.set_detect_anomaly(True) enabled, I get a different error instead:

>           loss.backward()

tests/test_modeling_common.py:467: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
env/lib/python3.9/site-packages/torch/_tensor.py:396: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensors = (tensor(nan, grad_fn=<AddBackward0>),), grad_tensors = None, retain_graph = False, create_graph = False, grad_variables = None, inputs = ()

    def backward(
        tensors: _TensorOrTensors,
        grad_tensors: Optional[_TensorOrTensors] = None,
        retain_graph: Optional[bool] = None,
        create_graph: bool = False,
        grad_variables: Optional[_TensorOrTensors] = None,
        inputs: Optional[_TensorOrTensors] = None,
    ) -> None:
        r"""Computes the sum of gradients of given tensors with respect to graph
        leaves.
    
        The graph is differentiated using the chain rule. If any of ``tensors``
        are non-scalar (i.e. their data has more than one element) and require
        gradient, then the Jacobian-vector product would be computed, in this
        case the function additionally requires specifying ``grad_tensors``.
        It should be a sequence of matching length, that contains the "vector"
        in the Jacobian-vector product, usually the gradient of the differentiated
        function w.r.t. corresponding tensors (``None`` is an acceptable value for
        all tensors that don't need gradient tensors).
    
        This function accumulates gradients in the leaves - you might need to zero
        ``.grad`` attributes or set them to ``None`` before calling it.
        See :ref:`Default gradient layouts<default-grad-layouts>`
        for details on the memory layout of accumulated gradients.
    
        .. note::
            Using this method with ``create_graph=True`` will create a reference cycle
            between the parameter and its gradient which can cause a memory leak.
            We recommend using ``autograd.grad`` when creating the graph to avoid this.
            If you have to use this function, make sure to reset the ``.grad`` fields of your
            parameters to ``None`` after use to break the cycle and avoid the leak.
    
        .. note::
    
            If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
            in a user-specified CUDA stream context, see
            :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
    
        .. note::
    
            When ``inputs`` are provided and a given input is not a leaf,
            the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients).
            It is an implementation detail on which the user should not rely.
            See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
    
        Args:
            tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
                computed.
            grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
                the Jacobian-vector product, usually gradients w.r.t. each element of
                corresponding tensors. None values can be specified for scalar Tensors or
                ones that don't require grad. If a None value would be acceptable for all
                grad_tensors, then this argument is optional.
            retain_graph (bool, optional): If ``False``, the graph used to compute the grad
                will be freed. Note that in nearly all cases setting this option to ``True``
                is not needed and often can be worked around in a much more efficient
                way. Defaults to the value of ``create_graph``.
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing to compute higher order derivative products.
                Defaults to ``False``.
            inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient
                be will accumulated into ``.grad``. All other Tensors will be ignored. If
                not provided, the gradient is accumulated into all the leaf Tensors that
                were used to compute the attr::tensors.
        """
        if grad_variables is not None:
            warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
            if grad_tensors is None:
                grad_tensors = grad_variables
            else:
                raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                                   "arguments both passed to backward(). Please only "
                                   "use 'grad_tensors'.")
        if inputs is not None and len(inputs) == 0:
            raise RuntimeError("'inputs' argument to backward() cannot be empty.")
    
        tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
        inputs = (inputs,) if isinstance(inputs, torch.Tensor) else \
            tuple(inputs) if inputs is not None else tuple()
    
        grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
        grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
        if retain_graph is None:
            retain_graph = create_graph
    
        # The reason we repeat same the comment below is that
        # some Python versions print out the first line of a multi-line function
        # calls in the traceback and some print out the last line
>       Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            tensors, grad_tensors_, retain_graph, create_graph, inputs,
            allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
E       RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.
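One detail I noticed in the traceback: the `tensors` argument is `tensor(nan, grad_fn=<AddBackward0>)`, so the loss appears to be nan already before the backward pass starts. Below is a minimal sketch of how a nan in the forward pass can lead to this exact anomaly-mode error. It is a toy example, not my actual model, and I am only assuming that a `pow` on a negative value is involved; the real cause may be different.

```python
import torch

caught = None
with torch.autograd.detect_anomaly():
    # Toy input, made up for illustration: a negative base with a fractional exponent.
    x = torch.tensor([-1.0, 4.0], requires_grad=True)
    y = x.pow(0.5)  # the forward result is already nan for the negative entry
    loss = y.sum()

    # Check the forward pass before calling backward.
    forward_has_nan = bool(torch.isnan(y).any())
    loss_is_nan = bool(torch.isnan(loss))

    try:
        loss.backward()
    except RuntimeError as err:
        # Anomaly mode reports the backward function that produced nan.
        caught = str(err)

print(forward_has_nan, loss_is_nan)
print(caught)
```

Running the same kind of `torch.isnan(...).any()` check on intermediate tensors in the forward pass (for example on `segment_means`) should narrow down where the first nan appears.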

Any help is greatly appreciated!